def test_single_assignment(self):
    """Single weights node assignment"""
    w1 = spn.Weights(3, num_weights=2)
    w2 = spn.Weights(0.3, num_weights=4)
    w3 = spn.Weights([0.4, 0.4, 1.2], num_weights=3)
    init1 = w1.initialize()
    # init2 = w2.initialize()  # don't initialize for testing
    init3 = w3.initialize()
    assign1 = w1.assign([1.0, 3.0])
    assign2 = w2.assign(0.5)
    assign3 = w3.assign(5)

    with tf.Session() as sess:
        sess.run([init1, init3])
        sess.run([assign1, assign2, assign3])
        val1 = sess.run(w1.get_value())
        val2 = sess.run(w2.get_value())
        val3 = sess.run(w3.get_value())
        val1_log = sess.run(tf.exp(w1.get_log_value()))
        val2_log = sess.run(tf.exp(w2.get_log_value()))
        val3_log = sess.run(tf.exp(w3.get_log_value()))

    self.assertEqual(val1.dtype, spn.conf.dtype.as_numpy_dtype())
    self.assertEqual(val2.dtype, spn.conf.dtype.as_numpy_dtype())
    self.assertEqual(val3.dtype, spn.conf.dtype.as_numpy_dtype())
    np.testing.assert_array_almost_equal(val1, [0.25, 0.75])
    np.testing.assert_array_almost_equal(val2, [0.25, 0.25, 0.25, 0.25])
    np.testing.assert_array_almost_equal(val3, [1 / 3, 1 / 3, 1 / 3])

    self.assertEqual(val1_log.dtype, spn.conf.dtype.as_numpy_dtype())
    self.assertEqual(val2_log.dtype, spn.conf.dtype.as_numpy_dtype())
    self.assertEqual(val3_log.dtype, spn.conf.dtype.as_numpy_dtype())
    np.testing.assert_array_almost_equal(val1_log, [0.25, 0.75])
    np.testing.assert_array_almost_equal(val2_log, [0.25, 0.25, 0.25, 0.25])
    np.testing.assert_array_almost_equal(val3_log, [1 / 3, 1 / 3, 1 / 3])
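# The expected values above follow from normalization: a Weights node
# apparently renormalizes whatever is assigned (or used as the initial value)
# so that the weights sum to one, broadcasting a scalar across all entries
# first. A minimal NumPy sketch of that assumed behavior (hypothetical helper,
# not part of the spn API; assumes the module-level `import numpy as np`
# already used throughout these tests):
def _expected_normalized_weights(value, num_weights):
    """Broadcast `value` to `num_weights` entries and normalize to sum to 1."""
    w = np.broadcast_to(np.asarray(value, dtype=np.float64), (num_weights,))
    return w / np.sum(w)

# _expected_normalized_weights([1.0, 3.0], 2) -> [0.25, 0.75]
# _expected_normalized_weights(0.5, 4)        -> [0.25, 0.25, 0.25, 0.25]
# _expected_normalized_weights(5, 3)          -> [1/3, 1/3, 1/3]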
def test_single_initialization(self):
    """Single weights node initialization"""
    w1 = spn.Weights(3, num_weights=2)
    w2 = spn.Weights(0.3, num_weights=4)
    w3 = spn.Weights([0.4, 0.4, 1.2], num_weights=3)
    init1 = w1.initialize()
    init2 = w2.initialize()
    init3 = w3.initialize()

    with tf.Session() as sess:
        sess.run([init1, init2, init3])
        val1 = sess.run(w1.get_value())
        val2 = sess.run(w2.get_value())
        val3 = sess.run(w3.get_value())
        val1_log = sess.run(tf.exp(w1.get_log_value()))
        val2_log = sess.run(tf.exp(w2.get_log_value()))
        val3_log = sess.run(tf.exp(w3.get_log_value()))

    self.assertEqual(val1.dtype, spn.conf.dtype.as_numpy_dtype())
    self.assertEqual(val2.dtype, spn.conf.dtype.as_numpy_dtype())
    self.assertEqual(val3.dtype, spn.conf.dtype.as_numpy_dtype())
    np.testing.assert_array_almost_equal(val1, [0.5, 0.5])
    np.testing.assert_array_almost_equal(val2, [0.25, 0.25, 0.25, 0.25])
    np.testing.assert_array_almost_equal(val3, [0.2, 0.2, 0.6])

    self.assertEqual(val1_log.dtype, spn.conf.dtype.as_numpy_dtype())
    self.assertEqual(val2_log.dtype, spn.conf.dtype.as_numpy_dtype())
    self.assertEqual(val3_log.dtype, spn.conf.dtype.as_numpy_dtype())
    np.testing.assert_array_almost_equal(val1_log, [0.5, 0.5])
    np.testing.assert_array_almost_equal(val2_log, [0.25, 0.25, 0.25, 0.25])
    np.testing.assert_array_almost_equal(val3_log, [0.2, 0.2, 0.6])
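# The same assumed normalization (see the _expected_normalized_weights sketch
# after test_single_assignment) reproduces these initialization expectations:
# _expected_normalized_weights(3, 2)               -> [0.5, 0.5]
# _expected_normalized_weights(0.3, 4)             -> [0.25, 0.25, 0.25, 0.25]
# _expected_normalized_weights([0.4, 0.4, 1.2], 3) -> [0.2, 0.2, 0.6]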
def test_input_flags(self):
    """Detection of different types of inputs"""
    inpt = spn.Input()
    self.assertFalse(inpt)
    self.assertFalse(inpt.is_op)
    self.assertFalse(inpt.is_var)
    self.assertFalse(inpt.is_param)

    n = spn.Sum()
    inpt = spn.Input(n)
    self.assertTrue(inpt)
    self.assertTrue(inpt.is_op)
    self.assertFalse(inpt.is_var)
    self.assertFalse(inpt.is_param)

    n = spn.ContVars()
    inpt = spn.Input(n)
    self.assertTrue(inpt)
    self.assertFalse(inpt.is_op)
    self.assertTrue(inpt.is_var)
    self.assertFalse(inpt.is_param)

    n = spn.Weights()
    inpt = spn.Input(n)
    self.assertTrue(inpt)
    self.assertFalse(inpt.is_op)
    self.assertFalse(inpt.is_var)
    self.assertTrue(inpt.is_param)
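# The flags checked above discriminate what kind of node an Input wraps. A
# minimal self-contained sketch of the assumed logic, with hypothetical
# stand-in classes (not LibSPN's real node hierarchy or implementation):
class _OpNode:
    """Stand-in for operation nodes such as spn.Sum or spn.Product."""

class _VarNode:
    """Stand-in for variable nodes such as spn.ContVars."""

class _ParamNode:
    """Stand-in for parameter nodes such as spn.Weights."""

class _InputSketch:
    def __init__(self, node=None):
        self.node = node

    def __bool__(self):
        # An empty Input is falsy, matching assertFalse(inpt) above.
        return self.node is not None

    @property
    def is_op(self):
        return isinstance(self.node, _OpNode)

    @property
    def is_var(self):
        return isinstance(self.node, _VarNode)

    @property
    def is_param(self):
        return isinstance(self.node, _ParamNode)

# _InputSketch() is falsy; _InputSketch(_OpNode()).is_op is True while the
# other flags stay False, mirroring the assertions in test_input_flags.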
def test_param_learning(self, softplus_scale):
    spn.conf.argmax_zero = True
    num_vars = 2
    num_components = 2
    batch_size = 32
    count_init = 100

    # Create means and variances
    means = np.array([[0, 1], [10, 15]])
    vars = np.array([[0.25, 0.5], [0.33, 0.67]])

    # Sample some data
    data0 = [stats.norm(loc=m, scale=np.sqrt(v)).rvs(batch_size // 2).astype(np.float32)
             for m, v in zip(means[0], vars[0])]
    data1 = [stats.norm(loc=m, scale=np.sqrt(v)).rvs(batch_size // 2).astype(np.float32)
             for m, v in zip(means[1], vars[1])]
    data = np.stack([np.concatenate(data0), np.concatenate(data1)], axis=-1)

    with tf.Graph().as_default() as graph:
        # Set up SPN
        gq = spn.GaussianLeaf(num_vars=num_vars, num_components=num_components,
                              initialization_data=data, total_counts_init=count_init,
                              learn_dist_params=True, softplus_scale=softplus_scale)

        mixture00 = spn.Sum((gq, [0, 1]), name="Mixture00")
        weights00 = spn.Weights(initializer=tf.initializers.constant([0.25, 0.75]),
                                num_weights=2)
        mixture00.set_weights(weights00)
        mixture01 = spn.Sum((gq, [0, 1]), name="Mixture01")
        weights01 = spn.Weights(initializer=tf.initializers.constant([0.75, 0.25]),
                                num_weights=2)
        mixture01.set_weights(weights01)

        mixture10 = spn.Sum((gq, [2, 3]), name="Mixture10")
        weights10 = spn.Weights(initializer=tf.initializers.constant([2 / 3, 1 / 3]),
                                num_weights=2)
        mixture10.set_weights(weights10)
        mixture11 = spn.Sum((gq, [2, 3]), name="Mixture11")
        weights11 = spn.Weights(initializer=tf.initializers.constant([1 / 3, 2 / 3]),
                                num_weights=2)
        mixture11.set_weights(weights11)

        prod0 = spn.Product(mixture00, mixture10, name="Prod0")
        prod1 = spn.Product(mixture01, mixture11, name="Prod1")
        root = spn.Sum(prod0, prod1, name="Root")
        root_weights = spn.Weights(initializer=tf.initializers.constant([1 / 2, 1 / 2]),
                                   num_weights=2)
        root.set_weights(root_weights)

        # Generate new data from slightly shifted Gaussians
        data0 = np.concatenate(
            [stats.norm(loc=m, scale=np.sqrt(v)).rvs(batch_size // 2).astype(np.float32)
             for m, v in zip(means[0] + 0.2, vars[0])])
        data1 = np.concatenate(
            [stats.norm(loc=m, scale=np.sqrt(v)).rvs(batch_size // 2).astype(np.float32)
             for m, v in zip(means[1] + 1.0, vars[1])])

        # Compute reference log probabilities of the leaf components
        empirical_means = gq._loc_init
        empirical_vars = (np.square(gq._scale_init) if not softplus_scale
                          else np.square(np.log(np.exp(gq._scale_init) + 1)))
        log_probs0 = [stats.norm(loc=m, scale=np.sqrt(v)).logpdf(data0)
                      for m, v in zip(empirical_means[0], empirical_vars[0])]
        log_probs1 = [stats.norm(loc=m, scale=np.sqrt(v)).logpdf(data1)
                      for m, v in zip(empirical_means[1], empirical_vars[1])]

        # Compute actual log probabilities of mixtures
        mixture00_val = np.logaddexp(log_probs0[0] + np.log(1 / 4),
                                     log_probs0[1] + np.log(3 / 4))
        mixture01_val = np.logaddexp(log_probs0[0] + np.log(3 / 4),
                                     log_probs0[1] + np.log(1 / 4))
        mixture10_val = np.logaddexp(log_probs1[0] + np.log(2 / 3),
                                     log_probs1[1] + np.log(1 / 3))
        mixture11_val = np.logaddexp(log_probs1[0] + np.log(1 / 3),
                                     log_probs1[1] + np.log(2 / 3))

        # Compute actual log probabilities of products
        prod0_val = mixture00_val + mixture10_val
        prod1_val = mixture01_val + mixture11_val

        # Compute the index of the max probability at the products layer
        prod_winner = np.argmax(np.stack([prod0_val, prod1_val], axis=-1), axis=-1)

        # Compute the indices of the max component per mixture
        component_winner00 = np.argmax(
            np.stack([log_probs0[0] + np.log(1 / 4), log_probs0[1] + np.log(3 / 4)],
                     axis=-1), axis=-1)
        component_winner01 = np.argmax(
            np.stack([log_probs0[0] + np.log(3 / 4), log_probs0[1] + np.log(1 / 4)],
                     axis=-1), axis=-1)
        component_winner10 = np.argmax(
            np.stack([log_probs1[0] + np.log(2 / 3), log_probs1[1] + np.log(1 / 3)],
                     axis=-1), axis=-1)
        component_winner11 = np.argmax(
            np.stack([log_probs1[0] + np.log(1 / 3), log_probs1[1] + np.log(2 / 3)],
                     axis=-1), axis=-1)

        # Initialize true counts
        counts_per_component = np.zeros((2, 2))
        sum_data_val = np.zeros((2, 2))
        sum_data_squared_val = np.zeros((2, 2))
        data00, data01, data10, data11 = [], [], [], []

        # Compute true counts
        counts_per_step = np.zeros((batch_size, num_vars, num_components))
        for i, (prod_ind, d0, d1) in enumerate(zip(prod_winner, data0, data1)):
            if prod_ind == 0:
                # mixture 00 and mixture 10
                counts_per_step[i, 0, component_winner00[i]] = 1
                counts_per_component[0, component_winner00[i]] += 1
                sum_data_val[0, component_winner00[i]] += data0[i]
                sum_data_squared_val[0, component_winner00[i]] += data0[i] * data0[i]
                (data00 if component_winner00[i] == 0 else data01).append(data0[i])

                counts_per_step[i, 1, component_winner10[i]] = 1
                counts_per_component[1, component_winner10[i]] += 1
                sum_data_val[1, component_winner10[i]] += data1[i]
                sum_data_squared_val[1, component_winner10[i]] += data1[i] * data1[i]
                (data10 if component_winner10[i] == 0 else data11).append(data1[i])
            else:
                counts_per_step[i, 0, component_winner01[i]] = 1
                counts_per_component[0, component_winner01[i]] += 1
                sum_data_val[0, component_winner01[i]] += data0[i]
                sum_data_squared_val[0, component_winner01[i]] += data0[i] * data0[i]
                (data00 if component_winner01[i] == 0 else data01).append(data0[i])

                counts_per_step[i, 1, component_winner11[i]] = 1
                counts_per_component[1, component_winner11[i]] += 1
                sum_data_val[1, component_winner11[i]] += data1[i]
                sum_data_squared_val[1, component_winner11[i]] += data1[i] * data1[i]
                (data10 if component_winner11[i] == 0 else data11).append(data1[i])

        # Setup learning Ops
        value_inference_type = spn.InferenceType.MARGINAL
        init_weights = spn.initialize_weights(root)
        learning = spn.EMLearning(root, log=True,
                                  value_inference_type=value_inference_type)
        reset_accumulators = learning.reset_accumulators()
        accumulate_updates = learning.accumulate_updates()
        update_spn = learning.update_spn()
        train_likelihood = learning.value.values[root]
        avg_train_likelihood = tf.reduce_mean(train_likelihood)

        # Setup feed dict and update ops
        fd = {gq: np.stack([data0, data1], axis=-1)}
        update_ops = gq._compute_hard_em_update(learning._mpe_path.counts[gq])

        with self.test_session(graph=graph) as sess:
            sess.run(init_weights)

            # Get log probabilities of Gaussian leaf
            log_probs = sess.run(learning.value.values[gq], fd)

            # Get log probabilities of mixtures
            mixture00_graph, mixture01_graph, mixture10_graph, mixture11_graph = \
                sess.run([learning.value.values[mixture00],
                          learning.value.values[mixture01],
                          learning.value.values[mixture10],
                          learning.value.values[mixture11]], fd)

            # Get log probabilities of products
            prod0_graph, prod1_graph = sess.run(
                [learning.value.values[prod0], learning.value.values[prod1]], fd)

            # Get counts for graph
            counts = sess.run(tf.reduce_sum(learning._mpe_path.counts[gq], axis=0), fd)
            counts_per_sample = sess.run(learning._mpe_path.counts[gq], fd)
            accum, sum_data_graph, sum_data_squared_graph = sess.run(
                [update_ops['accum'], update_ops['sum_data'],
                 update_ops['sum_data_squared']], fd)

        with self.test_session(graph=graph) as sess:
            sess.run(init_weights)
            sess.run(reset_accumulators)

            data_per_component_op = graph.get_tensor_by_name(
                "EMLearning/GaussianLeaf/DataPerComponent:0")
            squared_data_per_component_op = graph.get_tensor_by_name(
                "EMLearning/GaussianLeaf/SquaredDataPerComponent:0")

            update_vals, data_per_component_out, squared_data_per_component_out = \
                sess.run([accumulate_updates, data_per_component_op,
                          squared_data_per_component_op], fd)

            # Get likelihood before update
            lh_before = sess.run(avg_train_likelihood, fd)
            sess.run(update_spn)

            # Get likelihood after update
            lh_after = sess.run(avg_train_likelihood, fd)

            # Get variables after update
            total_counts_graph, scale_graph, mean_graph = sess.run(
                [gq._total_count_variable, gq.scale_variable, gq.loc_variable])

        self.assertAllClose(prod0_val, prod0_graph.ravel())
        self.assertAllClose(prod1_val, prod1_graph.ravel())

        self.assertAllClose(log_probs[:, 0], log_probs0[0])
        self.assertAllClose(log_probs[:, 1], log_probs0[1])
        self.assertAllClose(log_probs[:, 2], log_probs1[0])
        self.assertAllClose(log_probs[:, 3], log_probs1[1])

        self.assertAllClose(mixture00_val, mixture00_graph.ravel())
        self.assertAllClose(mixture01_val, mixture01_graph.ravel())
        self.assertAllClose(mixture10_val, mixture10_graph.ravel())
        self.assertAllClose(mixture11_val, mixture11_graph.ravel())

        self.assertAllEqual(counts, counts_per_component.ravel())
        self.assertAllEqual(accum, counts_per_component)
        self.assertAllClose(
            counts_per_step,
            counts_per_sample.reshape((batch_size, num_vars, num_components)))
        self.assertAllClose(sum_data_val, sum_data_graph)
        self.assertAllClose(sum_data_squared_val, sum_data_squared_graph)
        self.assertAllClose(total_counts_graph, count_init + counts_per_component)
        self.assertTrue(np.all(np.not_equal(mean_graph, gq._loc_init)))
        self.assertTrue(np.all(np.not_equal(scale_graph, gq._scale_init)))

        mean_new_vals = []
        variance_new_vals = []
        variance_left, variance_right = [], []
        for i, obs in enumerate([data00, data01, data10, data11]):
            # Note that this does not depend on accumulating anything!
            # It actually is copied (more-or-less) from
            # https://github.com/whsu/spn/blob/master/spn/normal_leaf_node.py
            x = np.asarray(obs).astype(np.float32)
            n = count_init
            k = len(obs)
            if softplus_scale:
                var_old = np.square(np.log(
                    np.exp(gq._scale_init.astype(np.float32)).ravel()[i] + 1))
            else:
                var_old = np.square(gq._scale_init.astype(np.float32)).ravel()[i]
            mean = (n * gq._loc_init.astype(np.float32).ravel()[i] +
                    np.sum(obs)) / (n + k)
            dx = x - gq._loc_init.astype(np.float32).ravel()[i]
            dm = mean - gq._loc_init.astype(np.float32).ravel()[i]
            var = (n * var_old + dx.dot(dx)) / (n + k) - dm * dm

            mean_new_vals.append(mean)
            variance_new_vals.append(var)
            variance_left.append((n * var_old + dx.dot(dx)) / (n + k))
            variance_right.append(dm * dm)

        mean_new_vals = np.asarray(mean_new_vals).reshape((2, 2))
        variance_new_vals = np.asarray(variance_new_vals).reshape((2, 2))

        def assert_non_zero_at_ij_equal(arr, i, j, truth):
            # Select i-th variable and j-th component
            arr = arr[:, i, j]
            self.assertAllClose(arr[arr != 0.0], truth)

        assert_non_zero_at_ij_equal(data_per_component_out, 0, 0, data00)
        assert_non_zero_at_ij_equal(data_per_component_out, 0, 1, data01)
        assert_non_zero_at_ij_equal(data_per_component_out, 1, 0, data10)
        assert_non_zero_at_ij_equal(data_per_component_out, 1, 1, data11)

        assert_non_zero_at_ij_equal(squared_data_per_component_out, 0, 0,
                                    np.square(data00))
        assert_non_zero_at_ij_equal(squared_data_per_component_out, 0, 1,
                                    np.square(data01))
        assert_non_zero_at_ij_equal(squared_data_per_component_out, 1, 0,
                                    np.square(data10))
        assert_non_zero_at_ij_equal(squared_data_per_component_out, 1, 1,
                                    np.square(data11))

        self.assertAllClose(mean_new_vals, mean_graph)
        # self.assertAllClose(
        #     np.asarray(variance_left).reshape((2, 2)), var_graph_left)
        self.assertAllClose(
            variance_new_vals,
            np.square(scale_graph if not softplus_scale
                      else np.log(np.exp(scale_graph) + 1)))
        self.assertGreater(lh_after, lh_before)
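# The loop above applies the streaming (hard-EM) moment update: with prior
# pseudo-count n, old mean mu and old variance v, and k new winning samples x,
#     mu' = (n * mu + sum(x)) / (n + k)
#     v'  = (n * v + sum((x - mu)^2)) / (n + k) - (mu' - mu)^2
# A small sanity check (not part of the original test; helper names here are
# hypothetical) that this update equals the pooled mean and population
# variance of prior-plus-new data; assumes the module-level `import numpy as np`:
def _online_mean_var(mu, v, n, x):
    """Streaming mean/variance update with prior pseudo-count n."""
    x = np.asarray(x, dtype=np.float64)
    k = x.size
    mu_new = (n * mu + x.sum()) / (n + k)
    dx = x - mu
    v_new = (n * v + dx.dot(dx)) / (n + k) - (mu_new - mu) ** 2
    return mu_new, v_new

def _check_online_mean_var():
    rng = np.random.RandomState(0)
    prior = rng.normal(size=100)       # pretend the prior moments came from these
    new = rng.normal(loc=0.5, size=8)  # new observations routed to this component
    mu_new, v_new = _online_mean_var(prior.mean(), prior.var(), prior.size, new)
    pooled = np.concatenate([prior, new])
    assert np.isclose(mu_new, pooled.mean())
    assert np.isclose(v_new, pooled.var())  # ddof=0, matching the formula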
def test_single_initialization(self):
    """Single weights node initialization"""
    # Single sum
    w1 = spn.Weights(tf.initializers.constant(3), num_weights=2)
    w2 = spn.Weights(tf.initializers.constant(0.3), num_weights=4)
    w3 = spn.Weights(tf.initializers.constant([0.4, 0.4, 1.2]), num_weights=3)
    # Multi sums
    w4 = spn.Weights(tf.initializers.constant(3), num_weights=2, num_sums=2)
    w5 = spn.Weights(tf.initializers.constant(0.3), num_weights=4, num_sums=3)
    w6 = spn.Weights(tf.initializers.random_uniform(0.0, 1.0),
                     num_weights=1, num_sums=4)
    init1 = w1.initialize()
    init2 = w2.initialize()
    init3 = w3.initialize()
    init4 = w4.initialize()
    init5 = w5.initialize()
    init6 = w6.initialize()

    with self.test_session() as sess:
        sess.run([init1, init2, init3, init4, init5, init6])
        val1 = sess.run(w1.get_value())
        val2 = sess.run(w2.get_value())
        val3 = sess.run(w3.get_value())
        val4 = sess.run(w4.get_value())
        val5 = sess.run(w5.get_value())
        val6 = sess.run(w6.get_value())
        val1_log = sess.run(tf.exp(w1.get_log_value()))
        val2_log = sess.run(tf.exp(w2.get_log_value()))
        val3_log = sess.run(tf.exp(w3.get_log_value()))
        val4_log = sess.run(tf.exp(w4.get_log_value()))
        val5_log = sess.run(tf.exp(w5.get_log_value()))
        val6_log = sess.run(tf.exp(w6.get_log_value()))

    self.assertEqual(val1.dtype, spn.conf.dtype.as_numpy_dtype())
    self.assertEqual(val2.dtype, spn.conf.dtype.as_numpy_dtype())
    self.assertEqual(val3.dtype, spn.conf.dtype.as_numpy_dtype())
    self.assertEqual(val4.dtype, spn.conf.dtype.as_numpy_dtype())
    self.assertEqual(val5.dtype, spn.conf.dtype.as_numpy_dtype())
    self.assertEqual(val6.dtype, spn.conf.dtype.as_numpy_dtype())
    np.testing.assert_array_almost_equal(val1, [[0.5, 0.5]])
    np.testing.assert_array_almost_equal(val2, [[0.25, 0.25, 0.25, 0.25]])
    np.testing.assert_array_almost_equal(val3, [[0.2, 0.2, 0.6]])
    np.testing.assert_array_almost_equal(val4, [[0.5, 0.5], [0.5, 0.5]])
    np.testing.assert_array_almost_equal(
        val5, [[0.25, 0.25, 0.25, 0.25],
               [0.25, 0.25, 0.25, 0.25],
               [0.25, 0.25, 0.25, 0.25]])
    np.testing.assert_array_almost_equal(val6, [[1.0], [1.0], [1.0], [1.0]])

    self.assertEqual(val1_log.dtype, spn.conf.dtype.as_numpy_dtype())
    self.assertEqual(val2_log.dtype, spn.conf.dtype.as_numpy_dtype())
    self.assertEqual(val3_log.dtype, spn.conf.dtype.as_numpy_dtype())
    np.testing.assert_array_almost_equal(val1_log, [[0.5, 0.5]])
    np.testing.assert_array_almost_equal(val2_log, [[0.25, 0.25, 0.25, 0.25]])
    np.testing.assert_array_almost_equal(val3_log, [[0.2, 0.2, 0.6]])
    np.testing.assert_array_almost_equal(val4_log, [[0.5, 0.5], [0.5, 0.5]])
    np.testing.assert_array_almost_equal(
        val5_log, [[0.25, 0.25, 0.25, 0.25],
                   [0.25, 0.25, 0.25, 0.25],
                   [0.25, 0.25, 0.25, 0.25]])
    np.testing.assert_array_almost_equal(val6_log, [[1.0], [1.0], [1.0], [1.0]])
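# For multi-sum weight nodes the expected arrays above suggest each row (one
# per sum) is normalized independently after broadcasting. A NumPy sketch of
# that assumed row-wise normalization (hypothetical helper; assumes the
# module-level `import numpy as np`):
def _expected_multi_sum_weights(value, num_weights, num_sums):
    """Broadcast `value` to (num_sums, num_weights), then normalize each row."""
    w = np.broadcast_to(np.asarray(value, dtype=np.float64),
                        (num_sums, num_weights))
    return w / w.sum(axis=1, keepdims=True)

# _expected_multi_sum_weights(3, 2, 2)   -> [[0.5, 0.5], [0.5, 0.5]]
# _expected_multi_sum_weights(0.3, 4, 3) -> three rows of [0.25, 0.25, 0.25, 0.25]
# With num_weights=1 (w6 above) every row normalizes to [1.0], regardless of
# the random uniform initial value.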