def test_compute_valid(self):
    """Check Sum.is_valid() for sums built with and without explicit IVs."""
    # Shared leaves and the sums/products that do not carry IVs.
    discrete = spn.IVs(num_vars=2, num_vals=4)
    continuous = spn.ContVars(num_vars=2)
    sum_a = spn.Sum((discrete, [0, 1, 2, 3]))
    sum_b = spn.Sum((discrete, [0, 1, 2, 4]))
    sum_c = spn.Sum((discrete, [0, 1, 2, 3]), (continuous, 0))
    prod_a = spn.Product((discrete, [0, 5]), (continuous, 0))
    prod_b = spn.Product((discrete, [1, 6]), (continuous, 0))
    prod_c = spn.Product((discrete, [1, 6]), (continuous, 1))
    sum_ab = spn.Sum(prod_a, prod_b)
    sum_ac = spn.Sum(prod_a, prod_c)

    # (node, expected validity), checked in this exact order.
    expectations = [
        (discrete, True),
        (continuous, True),
        (sum_a, True),
        (sum_b, False),
        (sum_c, False),
        (sum_ab, True),
        (sum_ac, False),
    ]
    for node, should_be_valid in expectations:
        check = self.assertTrue if should_be_valid else self.assertFalse
        check(node.is_valid())

    # Sums over the same two products, each with a different kind of IVs.
    with_generated_ivs = spn.Sum(prod_a, prod_b)
    with_generated_ivs.generate_ivs()
    self.assertTrue(with_generated_ivs.is_valid())

    with_continuous_ivs = spn.Sum(prod_a, prod_b)
    with_continuous_ivs.set_ivs(spn.ContVars(num_vars=2))
    self.assertFalse(with_continuous_ivs.is_valid())

    # Here is_valid() is expected to raise rather than return False.
    with_oversized_ivs = spn.Sum(prod_a, prod_b)
    with_oversized_ivs.set_ivs(spn.IVs(num_vars=2, num_vals=2))
    with self.assertRaises(spn.StructureError):
        with_oversized_ivs.is_valid()

    with_selected_ivs = spn.Sum(prod_a, prod_b)
    with_selected_ivs.set_ivs((discrete, [0, 3]))
    self.assertTrue(with_selected_ivs.is_valid())
def test_compute_valid(self):
    """Check Product.is_valid() for several input selections."""
    indicator = spn.IndicatorLeaf(num_vars=2, num_vals=4)
    raw = spn.RawLeaf(num_vars=2)

    # Each product is paired with the validity the test expects from it.
    cases = [
        (spn.Product((indicator, [0, 5])), True),
        (spn.Product((indicator, [0, 3])), False),
        (spn.Product((indicator, [0, 5]), raw), True),
        (spn.Product((indicator, [0, 3]), raw), False),
        (spn.Product((indicator, [0, 5]), raw, (indicator, 2)), False),
    ]
    for node, should_be_valid in cases:
        if should_be_valid:
            self.assertTrue(node.is_valid())
        else:
            self.assertFalse(node.is_valid())
def test_compute_mpe_path(self):
    """Check the per-input counts returned by Product._compute_mpe_path()."""
    v12 = spn.IVs(num_vars=2, num_vals=4)
    v34 = spn.ContVars(num_vars=2)
    v5 = spn.ContVars(num_vars=1)
    prod = spn.Product((v12, [0, 5]), v34, (v12, [3]), v5)

    parent_counts = tf.placeholder(tf.float32, shape=(None, 1))
    path_op = prod._compute_mpe_path(
        tf.identity(parent_counts),
        v12.get_value(),
        v34.get_value(),
        v12.get_value(),
        v5.get_value())

    with tf.Session() as sess:
        out = sess.run(path_op, feed_dict={parent_counts: [[0], [1], [2]]})

    # One expected array per product input, in input order.
    expected_per_input = [
        np.array([[0., 0., 0., 0., 0., 0., 0., 0.],
                  [1., 0., 0., 0., 0., 1., 0., 0.],
                  [2., 0., 0., 0., 0., 2., 0., 0.]], dtype=np.float32),
        np.array([[0., 0.],
                  [1., 1.],
                  [2., 2.]], dtype=np.float32),
        np.array([[0., 0., 0., 0., 0., 0., 0., 0.],
                  [0., 0., 0., 1., 0., 0., 0., 0.],
                  [0., 0., 0., 2., 0., 0., 0., 0.]], dtype=np.float32),
        np.array([[0.],
                  [1.],
                  [2.]], dtype=np.float32),
    ]
    for position, expected in enumerate(expected_per_input):
        np.testing.assert_array_almost_equal(out[position], expected)
def poon_single(inputs, num_vals, num_mixtures, num_subsets, inf_type,
                log=False, output=None):
    """Build a POON-like network from single-op nodes and its MPE path ops.

    For each of ``num_subsets`` consecutive blocks of ``num_vals`` input
    columns, ``num_mixtures`` Sum nodes are created. One Product node is
    created per combination that takes one Sum from every subset, and all
    products feed a single root Sum. ``output`` is not used here.

    Returns:
        A tuple ``(root, weight initializer, path_ops)`` where ``path_ops``
        holds the MPE path counts of every input node.
    """
    # Build a POON-like network with single-op nodes
    mixtures_per_subset = []
    for subset_idx in range(num_subsets):
        first_col = subset_idx * num_vals
        last_col = (subset_idx + 1) * num_vals
        mixtures = []
        for _ in range(num_mixtures):
            mixtures.append(
                spn.Sum((inputs, list(range(first_col, last_col)))))
        mixtures_per_subset.append(mixtures)

    product_nodes = [spn.Product(*combination)
                     for combination in product(*mixtures_per_subset)]
    root = spn.Sum(*product_nodes, name="root")

    # Generate dense SPN and all weights in the network
    spn.generate_weights(root)

    # Generate path ops based on inf_type and log
    mpe_path_gen = spn.MPEPath(value_inference_type=inf_type, log=bool(log))
    mpe_path_gen.get_mpe_path(root)

    input_nodes = inputs if isinstance(inputs, list) else [inputs]
    path_ops = [mpe_path_gen.counts[node] for node in input_nodes]
    return root, spn.initialize_weights(root), path_ops
def test_group_initialization(self):
    """Group initialization of weights nodes"""
    ivs_two_vals = spn.IVs(num_vars=1, num_vals=2)
    ivs_four_vals = spn.IVs(num_vars=1, num_vals=4)
    sum_explicit = spn.Sum(ivs_two_vals)
    sum_explicit.generate_weights([0.2, 0.3])
    sum_scalar = spn.Sum(ivs_four_vals)
    sum_scalar.generate_weights(5)
    root = spn.Product(sum_explicit, sum_scalar)
    init = spn.initialize_weights(root)

    sums = (sum_explicit, sum_scalar)
    with tf.Session() as sess:
        sess.run([init])
        # Weights read directly, then recovered from their log values.
        linear_vals = [sess.run(node.weights.node.get_value())
                       for node in sums]
        from_log_vals = [
            sess.run(tf.exp(node.weights.node.get_log_value()))
            for node in sums
        ]

    # The asserted weights are the initial values scaled to sum to one.
    expected_vals = ([0.4, 0.6], [0.25, 0.25, 0.25, 0.25])
    for observed in (linear_vals, from_log_vals):
        for arr in observed:
            self.assertEqual(arr.dtype, spn.conf.dtype.as_numpy_dtype())
        for arr, expected in zip(observed, expected_vals):
            np.testing.assert_array_almost_equal(arr, expected)
def test_group_initialization(self):
    """Group initialization of weights nodes.

    Creates Sum and ParSums nodes with constant weight initializers, runs the
    single initializer op obtained for the root Product, and checks dtype and
    values of every weights node. Values are checked both as read directly
    and as recovered via ``exp`` of the log values.
    """
    v1 = spn.IVs(num_vars=1, num_vals=2)
    v2 = spn.IVs(num_vars=1, num_vals=4)
    v3 = spn.IVs(num_vars=1, num_vals=2)
    v4 = spn.IVs(num_vars=1, num_vals=2)
    # Sum
    s1 = spn.Sum(v1)
    s1.generate_weights(tf.initializers.constant([0.2, 0.3]))
    s2 = spn.Sum(v2)
    s2.generate_weights(tf.initializers.constant(5))
    # ParSums
    s3 = spn.ParSums(*[v3, v4], num_sums=2)
    s3.generate_weights(
        tf.initializers.constant([0.1, 0.2, 0.3, 0.4, 0.4, 0.3, 0.2, 0.1]))
    s4 = spn.ParSums(*[v1, v2, v3, v4], num_sums=3)
    s4.generate_weights(tf.initializers.constant(2.0))
    # Product
    p = spn.Product(s1, s2, s3, s4)
    init = spn.initialize_weights(p)

    with self.test_session() as sess:
        sess.run([init])
        val1 = sess.run(s1.weights.node.get_value())
        val2 = sess.run(s2.weights.node.get_value())
        val3 = sess.run(s3.weights.node.get_value())
        val4 = sess.run(s4.weights.node.get_value())
        val1_log = sess.run(tf.exp(s1.weights.node.get_log_value()))
        val2_log = sess.run(tf.exp(s2.weights.node.get_log_value()))
        val3_log = sess.run(tf.exp(s3.weights.node.get_log_value()))
        val4_log = sess.run(tf.exp(s4.weights.node.get_log_value()))

    # Expected weights, one row per sum.
    expected1 = [[0.4, 0.6]]
    expected2 = [[0.25, 0.25, 0.25, 0.25]]
    expected3 = [[0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1]]
    expected4 = [[0.1] * 10, [0.1] * 10, [0.1] * 10]

    # Weights read directly
    self.assertEqual(val1.dtype, spn.conf.dtype.as_numpy_dtype())
    self.assertEqual(val2.dtype, spn.conf.dtype.as_numpy_dtype())
    self.assertEqual(val3.dtype, spn.conf.dtype.as_numpy_dtype())
    self.assertEqual(val4.dtype, spn.conf.dtype.as_numpy_dtype())
    np.testing.assert_array_almost_equal(val1, expected1)
    np.testing.assert_array_almost_equal(val2, expected2)
    np.testing.assert_array_almost_equal(val3, expected3)
    np.testing.assert_array_almost_equal(val4, expected4)

    # Weights recovered from the log values. The ParSums checks must use
    # val3_log / val4_log; re-checking val3 / val4 here would leave the log
    # path of the ParSums weights unverified.
    self.assertEqual(val1_log.dtype, spn.conf.dtype.as_numpy_dtype())
    self.assertEqual(val2_log.dtype, spn.conf.dtype.as_numpy_dtype())
    self.assertEqual(val3_log.dtype, spn.conf.dtype.as_numpy_dtype())
    self.assertEqual(val4_log.dtype, spn.conf.dtype.as_numpy_dtype())
    np.testing.assert_array_almost_equal(val1_log, expected1)
    np.testing.assert_array_almost_equal(val2_log, expected2)
    np.testing.assert_array_almost_equal(val3_log, expected3)
    np.testing.assert_array_almost_equal(val4_log, expected4)
def test_is_valid_false(self):
    """Check is_valid() on every node of a graph containing invalid nodes."""
    # Create graph
    v12 = spn.IVs(num_vars=2, num_vals=4, name="V12")
    v34 = spn.ContVars(num_vars=2, name="V34")
    s1 = spn.Sum((v12, [0, 1, 2, 3]), name="S1")
    s2 = spn.Sum((v12, [4, 5, 6, 7]), name="S2")
    p1 = spn.Product((v12, [0, 7]), name="P1")
    p2 = spn.Product((v12, [2, 3, 4]), name="P2")
    p3 = spn.Product(v34, name="P3")
    n1 = spn.Concat(s1, s2, p3, name="N1")
    n2 = spn.Concat(p1, p2, name="N2")
    p4 = spn.Product((n1, [0]), (n1, [1]), name="P4")
    p5 = spn.Product((n2, [0]), (n1, [2]), name="P5")
    s3 = spn.Sum(p4, n2, name="S3")
    p6 = spn.Product(s3, (n1, [2]), name="P6")
    s4 = spn.Sum(p5, p6, name="S4")

    # Test: nodes expected to be valid first, then those expected invalid.
    expected_valid = (v12, v34, s1, s2, p1, p3, p4, n1)
    expected_invalid = (p2, n2, s3, s4, p5, p6)
    for node in expected_valid:
        self.assertTrue(node.is_valid())
    for node in expected_invalid:
        self.assertFalse(node.is_valid())
def test_get_scope(self):
    """Check get_scope() for every node of a small SPN graph."""
    # Create graph
    v12 = spn.IVs(num_vars=2, num_vals=4, name="V12")
    v34 = spn.ContVars(num_vars=2, name="V34")
    s1 = spn.Sum((v12, [0, 1, 2, 3]), name="S1")
    s2 = spn.Sum((v12, [4, 5, 6, 7]), name="S2")
    p1 = spn.Product((v12, [0, 7]), name="P1")
    p2 = spn.Product((v12, [3, 4]), name="P2")
    p3 = spn.Product(v34, name="P3")
    n1 = spn.Concat(s1, s2, p3, name="N1")
    n2 = spn.Concat(p1, p2, name="N2")
    p4 = spn.Product((n1, [0]), (n1, [1]), name="P4")
    p5 = spn.Product((n2, [0]), (n1, [2]), name="P5")
    s3 = spn.Sum(p4, n2, name="S3")
    p6 = spn.Product(s3, (n1, [2]), name="P6")
    s4 = spn.Sum(p5, p6, name="S4")

    # Shorthands for the scopes that appear in the expectations.
    x0 = spn.Scope(v12, 0)
    x1 = spn.Scope(v12, 1)
    y0 = spn.Scope(v34, 0)
    y1 = spn.Scope(v34, 1)
    x01 = x0 | x1
    y01 = y0 | y1
    everything = x01 | y0 | y1

    # Test: (node, expected scope per output), checked in this order.
    expectations = [
        (v12, [x0, x0, x0, x0, x1, x1, x1, x1]),
        (v34, [y0, y1]),
        (s1, [x0]),
        (s2, [x1]),
        (p1, [x01]),
        (p2, [x01]),
        (p3, [y01]),
        (n1, [x0, x1, y01]),
        (n2, [x01, x01]),
        (p4, [x01]),
        (p5, [everything]),
        (s3, [x01]),
        (p6, [everything]),
        (s4, [everything]),
    ]
    for node, expected_scopes in expectations:
        self.assertListEqual(node.get_scope(), expected_scopes)
def product(inputs, num_inputs, num_input_cols, num_prods, inf_type,
            indices=None, log=False, output=None):
    """Build Product nodes under one root Sum and return its MPE path ops.

    ``inputs`` is iterated in lockstep with ``num_input_cols``. For each
    group of input nodes, one Product is created for every combination that
    picks a single column index (``0 .. n_cols - 1``) per node of the group.
    All products are connected to one root Sum whose weights are generated.
    ``num_inputs``, ``num_prods``, ``indices`` and ``output`` are not read.

    NOTE(review): the ``product`` called inside the body is resolved from the
    enclosing scope and is used like ``itertools.product``; presumably this
    function lives in a class namespace so the two names do not clash --
    confirm.

    Returns:
        A tuple ``(weight initializer, path_ops)`` where ``path_ops`` holds
        the MPE path counts of every input node, flattened across groups.
    """
    product_nodes = []
    for group, n_cols in zip(inputs, num_input_cols):
        # Every way of choosing one column index per node in the group.
        column_ids = map(int, np.arange(n_cols))
        for combination in product(column_ids, repeat=len(group)):
            # Pair each input node with its single selected column.
            children = [(node, [column])
                        for node, column in zip(group, combination)]
            product_nodes.append(spn.Product(*children))

    # Connect all product nodes to a single root Sum node and generate its
    # weights
    root = spn.Sum(*product_nodes)
    root.generate_weights()

    mpe_path_gen = spn.MPEPath(value_inference_type=inf_type, log=bool(log))
    mpe_path_gen.get_mpe_path(root)

    path_ops = []
    for node in chain.from_iterable(inputs):
        path_ops.append(mpe_path_gen.counts[node])
    return spn.initialize_weights(root), path_ops
def poon_single(inputs, num_vals, num_mixtures, num_subsets, inf_type,
                log=False, output=None):
    """Build a POON-like network from single-op nodes and its value op.

    For each of ``num_subsets`` consecutive blocks of ``num_vals`` input
    columns, ``num_mixtures`` Sum nodes are created. One Product node is
    created per combination that takes one Sum from every subset, and all
    products feed a single root Sum. ``output`` is not used here.

    Returns:
        A tuple ``(root, weight initializer, value_op)``. ``value_op`` is the
        root's log value when ``log`` is truthy, its plain value otherwise.
    """
    # Build a POON-like network with single-op nodes
    mixtures_per_subset = []
    for subset_idx in range(num_subsets):
        first_col = subset_idx * num_vals
        last_col = (subset_idx + 1) * num_vals
        mixtures = []
        for _ in range(num_mixtures):
            mixtures.append(
                spn.Sum((inputs, list(range(first_col, last_col)))))
        mixtures_per_subset.append(mixtures)

    product_nodes = [spn.Product(*combination)
                     for combination in product(*mixtures_per_subset)]
    root = spn.Sum(*product_nodes, name="root")

    # Generate dense SPN and all weights in the network
    spn.generate_weights(root)

    # Generate value ops based on inf_type and log
    compute_value = root.get_log_value if log else root.get_value
    value_op = compute_value(inference_type=inf_type)
    return root, spn.initialize_weights(root), value_op
def test_compute_log_mpe_path(self):
    """Check the per-input counts from Product._compute_log_mpe_path()."""
    v12 = spn.IndicatorLeaf(num_vars=2, num_vals=4)
    v34 = spn.RawLeaf(num_vars=2)
    v5 = spn.RawLeaf(num_vars=1)
    prod = spn.Product((v12, [0, 5]), v34, (v12, [3]), v5)

    parent_counts = tf.placeholder(tf.float32, shape=(None, 1))
    path_op = prod._compute_log_mpe_path(
        tf.identity(parent_counts),
        v12.get_value(),
        v34.get_value(),
        v12.get_value(),
        v5.get_value())

    with self.test_session() as sess:
        out = sess.run(path_op, feed_dict={parent_counts: [[0], [1], [2]]})

    # One expected array per product input, in input order.
    expected_per_input = [
        np.array([[0., 0., 0., 0., 0., 0., 0., 0.],
                  [1., 0., 0., 0., 0., 1., 0., 0.],
                  [2., 0., 0., 0., 0., 2., 0., 0.]], dtype=np.float32),
        np.array([[0., 0.],
                  [1., 1.],
                  [2., 2.]], dtype=np.float32),
        np.array([[0., 0., 0., 0., 0., 0., 0., 0.],
                  [0., 0., 0., 1., 0., 0., 0., 0.],
                  [0., 0., 0., 2., 0., 0., 0., 0.]], dtype=np.float32),
        np.array([[0.],
                  [1.],
                  [2.]], dtype=np.float32),
    ]
    for position, expected in enumerate(expected_per_input):
        self.assertAllClose(out[position], expected)
def test(inputs, feed, output):
    """Check a Product over `inputs` against `output` for all value ops.

    The product value is computed with MARGINAL and MPE inference, both
    directly and as exp() of the log value; all four must match `output`.
    `self` comes from the enclosing test method.
    """
    with self.subTest(inputs=inputs, feed=feed):
        node = spn.Product(*inputs)
        marginal = node.get_value(spn.InferenceType.MARGINAL)
        log_marginal = node.get_log_value(spn.InferenceType.MARGINAL)
        mpe = node.get_value(spn.InferenceType.MPE)
        log_mpe = node.get_log_value(spn.InferenceType.MPE)

        results = []
        with self.test_session() as sess:
            results.append(sess.run(marginal, feed_dict=feed))
            results.append(sess.run(tf.exp(log_marginal), feed_dict=feed))
            results.append(sess.run(mpe, feed_dict=feed))
            results.append(sess.run(tf.exp(log_mpe), feed_dict=feed))

        for result in results:
            self.assertAllClose(
                result,
                np.array(output, dtype=spn.conf.dtype.as_numpy_dtype()))
def test_comput_scope(self):
    """Calculating scope of Product.

    Builds a graph of leaves, Sum, Product and Concat nodes in which S1 and
    S4 carry generated latent indicators, then checks the scope reported by
    get_scope() for every node.
    """
    # Create a graph
    v12 = spn.IndicatorLeaf(num_vars=2, num_vals=4, name="V12")
    v34 = spn.RawLeaf(num_vars=2, name="V34")
    s1 = spn.Sum((v12, [0, 1, 2, 3]), name="S1")
    s1.generate_latent_indicators()
    s2 = spn.Sum((v12, [4, 5, 6, 7]), name="S2")
    p1 = spn.Product((v12, [0, 7]), name="P1")
    # Named "P2": the previous name "P1" duplicated the node above.
    p2 = spn.Product((v12, [3, 4]), name="P2")
    p3 = spn.Product(v34, name="P3")
    n1 = spn.Concat(s1, s2, p3, name="N1")
    n2 = spn.Concat(p1, p2, name="N2")
    p4 = spn.Product((n1, [0]), (n1, [1]), name="P4")
    p5 = spn.Product((n2, [0]), (n1, [2]), name="P5")
    s3 = spn.Sum(p4, n2, name="S3")
    p6 = spn.Product(s3, (n1, [2]), name="P6")
    s4 = spn.Sum(p5, p6, name="S4")
    s4.generate_latent_indicators()

    # Shorthands for the scopes that appear in the expectations.
    x0 = spn.Scope(v12, 0)
    x1 = spn.Scope(v12, 1)
    y0 = spn.Scope(v34, 0)
    y1 = spn.Scope(v34, 1)
    s1_latent = spn.Scope(s1.latent_indicators.node, 0)
    s4_latent = spn.Scope(s4.latent_indicators.node, 0)
    x01 = x0 | x1
    y01 = y0 | y1

    # Test: (node, expected scope per output), checked in this order.
    expectations = [
        (v12, [x0, x0, x0, x0, x1, x1, x1, x1]),
        (v34, [y0, y1]),
        (s1, [x0 | s1_latent]),
        (s2, [x1]),
        (p1, [x01]),
        (p2, [x01]),
        (p3, [y01]),
        (n1, [x0 | s1_latent, x1, y01]),
        (n2, [x01, x01]),
        (p4, [x01 | s1_latent]),
        (p5, [x01 | y0 | y1]),
        (s3, [x01 | s1_latent]),
        (p6, [x01 | y0 | y1 | s1_latent]),
        (s4, [x01 | y0 | y1 | s1_latent | s4_latent]),
    ]
    for node, expected_scopes in expectations:
        self.assertListEqual(node.get_scope(), expected_scopes)
def test_param_learning(self, softplus_scale):
    """Check hard-EM parameter learning of a GaussianLeaf against NumPy/SciPy.

    Builds an SPN with one GaussianLeaf (2 variables x 2 components), four
    mixtures, two products and a root sum. Reference log probabilities, MPE
    winners and accumulated statistics are computed with NumPy/SciPy and
    compared with the values produced by the graph before and after one
    EMLearning update.

    Args:
        softplus_scale: passed on to GaussianLeaf. When truthy, reference
            variances are computed as ``square(log(exp(scale) + 1))``
            instead of ``square(scale)``.
    """
    # NOTE(review): this writes to the module-level spn.conf and is never
    # restored in this function.
    spn.conf.argmax_zero = True
    num_vars = 2
    num_components = 2
    batch_size = 32
    count_init = 100

    # Create means and variances
    means = np.array([[0, 1], [10, 15]])
    # NOTE(review): `vars` shadows the Python builtin inside this function.
    vars = np.array([[0.25, 0.5], [0.33, 0.67]])

    # Sample some data
    # (batch_size // 2 samples per component; sampling is not seeded here)
    data0 = [
        stats.norm(loc=m, scale=np.sqrt(v)).rvs(batch_size // 2).astype(
            np.float32) for m, v in zip(means[0], vars[0])
    ]
    data1 = [
        stats.norm(loc=m, scale=np.sqrt(v)).rvs(batch_size // 2).astype(
            np.float32) for m, v in zip(means[1], vars[1])
    ]
    data = np.stack([np.concatenate(data0), np.concatenate(data1)], axis=-1)

    with tf.Graph().as_default() as graph:
        # Set up SPN
        gq = spn.GaussianLeaf(num_vars=num_vars,
                              num_components=num_components,
                              initialization_data=data,
                              total_counts_init=count_init,
                              learn_dist_params=True,
                              softplus_scale=softplus_scale)

        # Two mixtures over leaf outputs [0, 1], two over outputs [2, 3]
        mixture00 = spn.Sum((gq, [0, 1]), name="Mixture00")
        weights00 = spn.Weights(initializer=tf.initializers.constant(
            [0.25, 0.75]), num_weights=2)
        mixture00.set_weights(weights00)
        mixture01 = spn.Sum((gq, [0, 1]), name="Mixture01")
        weights01 = spn.Weights(initializer=tf.initializers.constant(
            [0.75, 0.25]), num_weights=2)
        mixture01.set_weights(weights01)

        mixture10 = spn.Sum((gq, [2, 3]), name="Mixture10")
        weights10 = spn.Weights(initializer=tf.initializers.constant(
            [2 / 3, 1 / 3]), num_weights=2)
        mixture10.set_weights(weights10)
        mixture11 = spn.Sum((gq, [2, 3]), name="Mixture11")
        weights11 = spn.Weights(initializer=tf.initializers.constant(
            [1 / 3, 2 / 3]), num_weights=2)
        mixture11.set_weights(weights11)

        prod0 = spn.Product(mixture00, mixture10, name="Prod0")
        prod1 = spn.Product(mixture01, mixture11, name="Prod1")

        root = spn.Sum(prod0, prod1, name="Root")
        root_weights = spn.Weights(initializer=tf.initializers.constant(
            [1 / 2, 1 / 2]), num_weights=2)
        root.set_weights(root_weights)

        # Generate new data from slightly shifted Gaussians
        data0 = np.concatenate([
            stats.norm(loc=m, scale=np.sqrt(v)).rvs(batch_size // 2).astype(np.float32)
            for m, v in zip(means[0] + 0.2, vars[0])
        ])
        data1 = np.concatenate([
            stats.norm(loc=m, scale=np.sqrt(v)).rvs(batch_size // 2).astype(np.float32)
            for m, v in zip(means[1] + 1.0, vars[1])
        ])

        # Compute actual log probabilities of roots
        empirical_means = gq._loc_init
        empirical_vars = np.square(
            gq._scale_init) if not softplus_scale else np.square(
                np.log(np.exp(gq._scale_init) + 1))
        log_probs0 = [
            stats.norm(loc=m, scale=np.sqrt(v)).logpdf(data0)
            for m, v in zip(empirical_means[0], empirical_vars[0])
        ]
        log_probs1 = [
            stats.norm(loc=m, scale=np.sqrt(v)).logpdf(data1)
            for m, v in zip(empirical_means[1], empirical_vars[1])
        ]

        # Compute actual log probabilities of mixtures
        # (the log weights mirror the constants given to spn.Weights above)
        mixture00_val = np.logaddexp(log_probs0[0] + np.log(1 / 4),
                                     log_probs0[1] + np.log(3 / 4))
        mixture01_val = np.logaddexp(log_probs0[0] + np.log(3 / 4),
                                     log_probs0[1] + np.log(1 / 4))
        mixture10_val = np.logaddexp(log_probs1[0] + np.log(2 / 3),
                                     log_probs1[1] + np.log(1 / 3))
        mixture11_val = np.logaddexp(log_probs1[0] + np.log(1 / 3),
                                     log_probs1[1] + np.log(2 / 3))

        # Compute actual log probabilities of products
        prod0_val = mixture00_val + mixture10_val
        prod1_val = mixture01_val + mixture11_val

        # Compute the index of the max probability at the products layer
        prod_winner = np.argmax(np.stack([prod0_val, prod1_val], axis=-1),
                                axis=-1)

        # Compute the indices of the max component per mixture
        component_winner00 = np.argmax(np.stack(
            [log_probs0[0] + np.log(1 / 4),
             log_probs0[1] + np.log(3 / 4)], axis=-1), axis=-1)
        component_winner01 = np.argmax(np.stack(
            [log_probs0[0] + np.log(3 / 4),
             log_probs0[1] + np.log(1 / 4)], axis=-1), axis=-1)
        component_winner10 = np.argmax(np.stack(
            [log_probs1[0] + np.log(2 / 3),
             log_probs1[1] + np.log(1 / 3)], axis=-1), axis=-1)
        component_winner11 = np.argmax(np.stack(
            [log_probs1[0] + np.log(1 / 3),
             log_probs1[1] + np.log(2 / 3)], axis=-1), axis=-1)

        # Initialize true counts
        counts_per_component = np.zeros((2, 2))
        sum_data_val = np.zeros((2, 2))
        sum_data_squared_val = np.zeros((2, 2))

        # Samples assigned to each (variable, component) pair
        data00 = []
        data01 = []
        data10 = []
        data11 = []

        # Compute true counts
        counts_per_step = np.zeros((batch_size, num_vars, num_components))
        # d0 and d1 are not used in the loop body; samples are read by index
        for i, (prod_ind, d0, d1) in enumerate(zip(prod_winner, data0, data1)):
            if prod_ind == 0:
                # mixture 00 and mixture 10
                counts_per_step[i, 0, component_winner00[i]] = 1
                counts_per_component[0, component_winner00[i]] += 1
                sum_data_val[0, component_winner00[i]] += data0[i]
                sum_data_squared_val[
                    0, component_winner00[i]] += data0[i] * data0[i]
                (data00 if component_winner00[i] == 0 else data01).append(
                    data0[i])

                counts_per_step[i, 1, component_winner10[i]] = 1
                counts_per_component[1, component_winner10[i]] += 1
                sum_data_val[1, component_winner10[i]] += data1[i]
                sum_data_squared_val[
                    1, component_winner10[i]] += data1[i] * data1[i]
                (data10 if component_winner10[i] == 0 else data11).append(
                    data1[i])
            else:
                # mixture 01 and mixture 11
                counts_per_step[i, 0, component_winner01[i]] = 1
                counts_per_component[0, component_winner01[i]] += 1
                sum_data_val[0, component_winner01[i]] += data0[i]
                sum_data_squared_val[
                    0, component_winner01[i]] += data0[i] * data0[i]
                (data00 if component_winner01[i] == 0 else data01).append(
                    data0[i])

                counts_per_step[i, 1, component_winner11[i]] = 1
                counts_per_component[1, component_winner11[i]] += 1
                sum_data_val[1, component_winner11[i]] += data1[i]
                sum_data_squared_val[
                    1, component_winner11[i]] += data1[i] * data1[i]
                (data10 if component_winner11[i] == 0 else data11).append(
                    data1[i])

        # Setup learning Ops
        value_inference_type = spn.InferenceType.MARGINAL
        init_weights = spn.initialize_weights(root)
        learning = spn.EMLearning(
            root, log=True, value_inference_type=value_inference_type)
        reset_accumulators = learning.reset_accumulators()
        accumulate_updates = learning.accumulate_updates()
        update_spn = learning.update_spn()
        train_likelihood = learning.value.values[root]
        avg_train_likelihood = tf.reduce_mean(train_likelihood)

        # Setup feed dict and update ops
        fd = {gq: np.stack([data0, data1], axis=-1)}
        update_ops = gq._compute_hard_em_update(
            learning._mpe_path.counts[gq])

        # First session: read values, counts and update statistics only
        with self.test_session(graph=graph) as sess:
            sess.run(init_weights)

            # Get log probabilities of Gaussian leaf
            log_probs = sess.run(learning.value.values[gq], fd)

            # Get log probabilities of mixtures
            mixture00_graph, mixture01_graph, mixture10_graph, mixture11_graph = sess.run(
                [
                    learning.value.values[mixture00],
                    learning.value.values[mixture01],
                    learning.value.values[mixture10],
                    learning.value.values[mixture11]
                ], fd)

            # Get log probabilities of products
            prod0_graph, prod1_graph = sess.run([
                learning.value.values[prod0], learning.value.values[prod1]
            ], fd)

            # Get counts for graph
            counts = sess.run(
                tf.reduce_sum(learning._mpe_path.counts[gq], axis=0), fd)
            counts_per_sample = sess.run(learning._mpe_path.counts[gq], fd)

            accum, sum_data_graph, sum_data_squared_graph = sess.run([
                update_ops['accum'], update_ops['sum_data'],
                update_ops['sum_data_squared']
            ], fd)

        # Second session: accumulate, apply the update, read variables back
        with self.test_session(graph=graph) as sess:
            sess.run(init_weights)
            sess.run(reset_accumulators)

            # Intermediate tensors are looked up by their graph names
            data_per_component_op = graph.get_tensor_by_name(
                "EMLearning/GaussianLeaf/DataPerComponent:0")
            squared_data_per_component_op = graph.get_tensor_by_name(
                "EMLearning/GaussianLeaf/SquaredDataPerComponent:0")

            # update_vals is not used afterwards
            update_vals, data_per_component_out, squared_data_per_component_out = sess.run(
                [
                    accumulate_updates, data_per_component_op,
                    squared_data_per_component_op
                ], fd)

            # Get likelihood before update
            lh_before = sess.run(avg_train_likelihood, fd)
            sess.run(update_spn)

            # Get likelihood after update
            lh_after = sess.run(avg_train_likelihood, fd)

            # Get variables after update
            total_counts_graph, scale_graph, mean_graph = sess.run([
                gq._total_count_variable, gq.scale_variable, gq.loc_variable
            ])

        # Graph values must match the NumPy/SciPy reference values
        self.assertAllClose(prod0_val, prod0_graph.ravel())
        self.assertAllClose(prod1_val, prod1_graph.ravel())

        self.assertAllClose(log_probs[:, 0], log_probs0[0])
        self.assertAllClose(log_probs[:, 1], log_probs0[1])
        self.assertAllClose(log_probs[:, 2], log_probs1[0])
        self.assertAllClose(log_probs[:, 3], log_probs1[1])

        self.assertAllClose(mixture00_val, mixture00_graph.ravel())
        self.assertAllClose(mixture01_val, mixture01_graph.ravel())
        self.assertAllClose(mixture10_val, mixture10_graph.ravel())
        self.assertAllClose(mixture11_val, mixture11_graph.ravel())

        # Counts and accumulated statistics must match the reference ones
        self.assertAllEqual(counts, counts_per_component.ravel())
        self.assertAllEqual(accum, counts_per_component)

        self.assertAllClose(
            counts_per_step,
            counts_per_sample.reshape((batch_size, num_vars, num_components)))

        self.assertAllClose(sum_data_val, sum_data_graph)
        self.assertAllClose(sum_data_squared_val, sum_data_squared_graph)

        self.assertAllClose(total_counts_graph,
                            count_init + counts_per_component)
        # Every location and scale must differ from its initial value
        self.assertTrue(np.all(np.not_equal(mean_graph, gq._loc_init)))
        self.assertTrue(np.all(np.not_equal(scale_graph, gq._scale_init)))

        mean_new_vals = []
        variance_new_vals = []
        # variance_left / variance_right are filled below but only referenced
        # by the commented-out assertion near the end
        variance_left, variance_right = [], []
        for i, obs in enumerate([data00, data01, data10, data11]):
            # Note that this does not depend on accumulating anything!
            # It actually is copied (more-or-less) from
            # https://github.com/whsu/spn/blob/master/spn/normal_leaf_node.py
            x = np.asarray(obs).astype(np.float32)
            n = count_init
            k = len(obs)
            if softplus_scale:
                var_old = np.square(
                    np.log(
                        np.exp(gq._scale_init.astype(np.float32)).ravel()[i]
                        + 1))
            else:
                var_old = np.square(gq._scale_init.astype(
                    np.float32)).ravel()[i]
            mean = (n * gq._loc_init.astype(np.float32).ravel()[i] +
                    np.sum(obs)) / (n + k)
            dx = x - gq._loc_init.astype(np.float32).ravel()[i]
            dm = mean - gq._loc_init.astype(np.float32).ravel()[i]
            var = (n * var_old + dx.dot(dx)) / (n + k) - dm * dm

            mean_new_vals.append(mean)
            variance_new_vals.append(var)
            variance_left.append((n * var_old + dx.dot(dx)) / (n + k))
            variance_right.append(dm * dm)

        mean_new_vals = np.asarray(mean_new_vals).reshape((2, 2))
        variance_new_vals = np.asarray(variance_new_vals).reshape((2, 2))

        def assert_non_zero_at_ij_equal(arr, i, j, truth):
            # Select i-th variable and j-th component
            arr = arr[:, i, j]
            # Compare only the non-zero entries with the reference samples
            self.assertAllClose(arr[arr != 0.0], truth)

        assert_non_zero_at_ij_equal(data_per_component_out, 0, 0, data00)
        assert_non_zero_at_ij_equal(data_per_component_out, 0, 1, data01)
        assert_non_zero_at_ij_equal(data_per_component_out, 1, 0, data10)
        assert_non_zero_at_ij_equal(data_per_component_out, 1, 1, data11)

        assert_non_zero_at_ij_equal(squared_data_per_component_out, 0, 0,
                                    np.square(data00))
        assert_non_zero_at_ij_equal(squared_data_per_component_out, 0, 1,
                                    np.square(data01))
        assert_non_zero_at_ij_equal(squared_data_per_component_out, 1, 0,
                                    np.square(data10))
        assert_non_zero_at_ij_equal(squared_data_per_component_out, 1, 1,
                                    np.square(data11))

        self.assertAllClose(mean_new_vals, mean_graph)
        # self.assertAllClose(np.asarray(variance_left).reshape((2, 2)), var_graph_left)
        self.assertAllClose(
            variance_new_vals,
            np.square(scale_graph if not softplus_scale else np.
                      log(np.exp(scale_graph) + 1)))

        # The update must increase the average training likelihood
        self.assertGreater(lh_after, lh_before)
def test_compute_valid(self):
    """Check SumsLayer.is_valid() across inputs, sum sizes and indicators."""

    def expect_valid(node):
        self.assertTrue(node.is_valid())

    def expect_invalid(node):
        self.assertFalse(node.is_valid())

    def expect_structure_error(node):
        with self.assertRaises(spn.StructureError):
            node.is_valid()

    # Without latent indicators
    v12 = spn.IndicatorLeaf(num_vars=2, num_vals=4)
    v34 = spn.RawLeaf(num_vars=2)

    three_equal_sums = spn.SumsLayer((v12, [0, 1, 2, 3]),
                                     (v12, [0, 1, 2, 3]),
                                     (v12, [0, 1, 2, 3]),
                                     num_or_size_sums=3)
    expect_valid(three_equal_sums)

    # Same input columns, with and without a split into sums of 3 and 1.
    unsplit = spn.SumsLayer((v12, [0, 1, 2, 4]), name="S2")
    split = spn.SumsLayer((v12, [0, 1, 2, 4]), num_or_size_sums=[3, 1],
                          name="S2b")
    expect_valid(split)
    expect_invalid(unsplit)

    mixed_two_sums = spn.SumsLayer((v12, [0, 1, 2, 3]), (v34, 0),
                                   (v12, [0, 1, 2, 3]), (v34, 0),
                                   num_or_size_sums=2)
    mixed_per_input = spn.SumsLayer((v12, [0, 1, 2, 3]), (v34, 0),
                                    (v12, [0, 1, 2, 3]), (v34, 0),
                                    num_or_size_sums=[4, 1, 4, 1])
    mixed_merged_tail = spn.SumsLayer((v12, [0, 1, 2, 3]), (v34, 0),
                                      (v12, [0, 1, 2, 3]), (v34, 0),
                                      num_or_size_sums=[4, 1, 5])
    expect_invalid(mixed_two_sums)
    expect_valid(mixed_per_input)
    expect_invalid(mixed_merged_tail)

    p1 = spn.Product((v12, [0, 5]), (v34, 0))
    p2 = spn.Product((v12, [1, 6]), (v34, 0))
    p3 = spn.Product((v12, [1, 6]), (v34, 1))
    over_p1_p2 = spn.SumsLayer(p1, p2, p1, p2, num_or_size_sums=2)
    over_p1_p3 = spn.SumsLayer(p1, p3, p1, p3, p1, p3, num_or_size_sums=3)
    sizes_2_1 = spn.SumsLayer(p1, p2, p3, num_or_size_sums=[2, 1])
    sizes_1_2 = spn.SumsLayer(p1, p2, p3, num_or_size_sums=[1, 2])
    sizes_2_1_2 = spn.SumsLayer(p1, p2, p3, p2, p1,
                                num_or_size_sums=[2, 1, 2])
    expect_valid(over_p1_p2)
    expect_invalid(over_p1_p3)  # p1 and p3 different scopes
    expect_valid(sizes_2_1)
    expect_invalid(sizes_1_2)  # p2 and p3 different scopes
    expect_valid(sizes_2_1_2)

    # With latent indicators
    generated = spn.SumsLayer(p1, p2, p1, p2, p1, p2, num_or_size_sums=3)
    generated.generate_latent_indicators()
    expect_valid(generated)

    # (inputs, num_or_size_sums, indicator factory, expected outcome).
    # Factories are called only after the layer is built, so nodes are
    # created in the same order as when written out statement by statement.
    quad = (p1, p2, p1, p2)
    latent_cases = [
        ((p1, p2), 1,
         lambda: spn.RawLeaf(num_vars=2), expect_invalid),
        ((p1, p2, p3), 3,
         lambda: spn.RawLeaf(num_vars=3), expect_valid),
        ((p1, p2, p3), [2, 1],
         lambda: spn.RawLeaf(num_vars=3), expect_invalid),
        (quad, 2,
         lambda: spn.IndicatorLeaf(num_vars=3, num_vals=2),
         expect_structure_error),
        (quad, [1, 3],
         lambda: spn.RawLeaf(num_vars=2), expect_structure_error),
        (quad, [1, 3],
         lambda: spn.RawLeaf(num_vars=3), expect_structure_error),
        (quad, 2,
         lambda: spn.IndicatorLeaf(num_vars=1, num_vals=4), expect_valid),
        (quad, [1, 3],
         lambda: spn.IndicatorLeaf(num_vars=1, num_vals=4), expect_valid),
        (quad, [1, 3],
         lambda: spn.IndicatorLeaf(num_vars=2, num_vals=2), expect_invalid),
        (quad, 2,
         lambda: spn.IndicatorLeaf(num_vars=2, num_vals=2), expect_valid),
        (quad, [1, 2, 1],
         lambda: spn.IndicatorLeaf(num_vars=2, num_vals=2), expect_invalid),
        (quad, 2,
         lambda: (v12, [0, 3, 5, 7]), expect_valid),
        (quad, [1, 2, 1],
         lambda: (v12, [0, 3, 5, 7]), expect_invalid),
    ]
    for children, sizes, make_indicators, expect in latent_cases:
        layer = spn.SumsLayer(*children, num_or_size_sums=sizes)
        layer.set_latent_indicators(make_indicators())
        expect(layer)
def test_comput_scope(self):
    """Check get_scope() of PermuteProducts and the nodes around them."""
    # Create graph
    v12 = spn.IndicatorLeaf(num_vars=2, num_vals=4, name="V12")
    v34 = spn.RawLeaf(num_vars=2, name="V34")
    s1 = spn.Sum((v12, [0, 1, 2, 3]), name="S1")
    s1.generate_latent_indicators()
    s2 = spn.Sum((v12, [4, 5, 6, 7]), name="S2")
    p1 = spn.Product((v12, [0, 7]), name="P1")
    p2 = spn.Product((v12, [3, 4]), name="P2")
    p3 = spn.Product(v34, name="P3")
    n1 = spn.Concat(s1, s2, p3, name="N1")
    n2 = spn.Concat(p1, p2, name="N2")
    pp1 = spn.PermuteProducts(n1, n2, name="PP1")  # num_prods = 6
    pp2 = spn.PermuteProducts((n1, [0, 1]), (n2, [0]),
                              name="PP2")  # num_prods = 2
    pp3 = spn.PermuteProducts(n2, p3, name="PP3")  # num_prods = 2
    pp4 = spn.PermuteProducts(p2, p3, name="PP4")  # num_prods = 1
    pp5 = spn.PermuteProducts((n2, [0, 1]), name="PP5")  # num_prods = 1
    pp6 = spn.PermuteProducts(p3, name="PP6")  # num_prods = 1
    n3 = spn.Concat((pp1, [0, 2, 3]), pp2, pp4, name="N3")
    s3 = spn.Sum((pp1, [0, 2, 4]), (pp1, [1, 3, 5]), pp2, pp3, (pp4, 0),
                 pp5, pp6, name="S3")
    s3.generate_latent_indicators()
    n4 = spn.Concat((pp3, [0, 1]), pp5, (pp6, 0), name="N4")
    pp7 = spn.PermuteProducts(n3, s3, n4, name="PP7")  # num_prods = 24
    pp8 = spn.PermuteProducts(n3, name="PP8")  # num_prods = 1
    pp9 = spn.PermuteProducts((n4, [0, 1, 2, 3]),
                              name="PP9")  # num_prods = 1

    # Shorthands for the scopes that appear in the expectations.
    x0 = spn.Scope(v12, 0)
    x1 = spn.Scope(v12, 1)
    y0 = spn.Scope(v34, 0)
    y1 = spn.Scope(v34, 1)
    s1_latent = spn.Scope(s1.latent_indicators.node, 0)
    s3_latent = spn.Scope(s3.latent_indicators.node, 0)
    x01 = x0 | x1
    y01 = y0 | y1
    x01_latent = x01 | s1_latent
    x01_y01 = x01 | y0 | y1
    full = x01 | y0 | y1 | s1_latent | s3_latent

    # Test: (node, expected scope per output), checked in this order.
    expectations = [
        (v12, [x0, x0, x0, x0, x1, x1, x1, x1]),
        (v34, [y0, y1]),
        (s1, [x0 | s1_latent]),
        (s2, [x1]),
        (p1, [x01]),
        (p2, [x01]),
        (p3, [y01]),
        (n1, [x0 | s1_latent, x1, y01]),
        (n2, [x01, x01]),
        (pp1, [x01_latent, x01_latent, x01, x01, x01_y01, x01_y01]),
        (pp2, [x01_latent, x01]),
        (pp3, [x01_y01, x01_y01]),
        (pp4, [x01_y01]),
        (pp5, [x01]),
        (pp6, [y01]),
        (n3, [x01_latent, x01, x01, x01_latent, x01, x01_y01]),
        (s3, [full]),
        (n4, [x01_y01, x01_y01, x01, y01]),
        (pp7, [full] * 24),
        (pp8, [x01 | s1_latent | y0 | y1]),
        (pp9, [x01_y01]),
    ]
    for node, expected_scopes in expectations:
        self.assertListEqual(node.get_scope(), expected_scopes)