def test_group_initialization(self):
    """Group initialization of weights nodes"""
    v1 = spn.IndicatorLeaf(num_vars=1, num_vals=2)
    v2 = spn.IndicatorLeaf(num_vars=1, num_vals=4)
    v3 = spn.IndicatorLeaf(num_vars=1, num_vals=2)
    v4 = spn.IndicatorLeaf(num_vars=1, num_vals=2)
    # Sum
    s1 = spn.Sum(v1)
    s1.generate_weights(tf.initializers.constant([0.2, 0.3]))
    s2 = spn.Sum(v2)
    s2.generate_weights(tf.initializers.constant(5))
    # ParallelSums
    s3 = spn.ParallelSums(*[v3, v4], num_sums=2)
    s3.generate_weights(
        tf.initializers.constant([0.1, 0.2, 0.3, 0.4, 0.4, 0.3, 0.2, 0.1]))
    s4 = spn.ParallelSums(*[v1, v2, v3, v4], num_sums=3)
    s4.generate_weights(tf.initializers.constant(2.0))
    # Product
    p = spn.Product(s1, s2, s3, s4)
    init = spn.initialize_weights(p)

    with self.test_session() as sess:
        sess.run([init])
        val1 = sess.run(s1.weights.node.get_value())
        val2 = sess.run(s2.weights.node.get_value())
        val3 = sess.run(s3.weights.node.get_value())
        val4 = sess.run(s4.weights.node.get_value())
        val1_log = sess.run(tf.exp(s1.weights.node.get_log_value()))
        val2_log = sess.run(tf.exp(s2.weights.node.get_log_value()))
        val3_log = sess.run(tf.exp(s3.weights.node.get_log_value()))
        val4_log = sess.run(tf.exp(s4.weights.node.get_log_value()))

    self.assertEqual(val1.dtype, spn.conf.dtype.as_numpy_dtype())
    self.assertEqual(val2.dtype, spn.conf.dtype.as_numpy_dtype())
    self.assertEqual(val3.dtype, spn.conf.dtype.as_numpy_dtype())
    self.assertEqual(val4.dtype, spn.conf.dtype.as_numpy_dtype())
    np.testing.assert_array_almost_equal(val1, [[0.4, 0.6]])
    np.testing.assert_array_almost_equal(val2, [[0.25, 0.25, 0.25, 0.25]])
    np.testing.assert_array_almost_equal(
        val3, [[0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1]])
    np.testing.assert_array_almost_equal(
        val4, [[0.1] * 10, [0.1] * 10, [0.1] * 10])

    self.assertEqual(val1_log.dtype, spn.conf.dtype.as_numpy_dtype())
    self.assertEqual(val2_log.dtype, spn.conf.dtype.as_numpy_dtype())
    self.assertEqual(val3_log.dtype, spn.conf.dtype.as_numpy_dtype())
    self.assertEqual(val4_log.dtype, spn.conf.dtype.as_numpy_dtype())
    np.testing.assert_array_almost_equal(val1_log, [[0.4, 0.6]])
    np.testing.assert_array_almost_equal(val2_log, [[0.25, 0.25, 0.25, 0.25]])
    np.testing.assert_array_almost_equal(
        val3_log, [[0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1]])
    np.testing.assert_array_almost_equal(
        val4_log, [[0.1] * 10, [0.1] * 10, [0.1] * 10])
def test_stochastic_argmax(self, argmax_zero):
    spn.conf.argmax_zero = argmax_zero
    N = 100000
    s = spn.Sum()
    x = tf.constant(
        np.repeat([[0, 1, 1, 0, 1], [1, 0, 0, 1, 0]], repeats=N, axis=0))
    argmax_op = tf.squeeze(s._reduce_argmax(tf.expand_dims(x, 1)))

    with self.test_session() as sess:
        argmax = sess.run(argmax_op)

    hist_first, _ = np.histogram(argmax[:N], bins=list(range(6)))
    hist_second, _ = np.histogram(argmax[N:], bins=list(range(6)))

    if argmax_zero:
        self.assertEqual(hist_first[1], N)
        self.assertEqual(hist_second[0], N)
    else:
        for i in [1, 2, 4]:
            self.assertLess(hist_first[i], N / 3 + N / 6)
            self.assertGreater(hist_first[i], N / 3 - N / 6)
        for i in [0, 3]:
            self.assertLess(hist_second[i], N / 2 + N / 6)
            self.assertGreater(hist_second[i], N / 2 - N / 6)
def test_sampling(self, sample_prob):
    N = 100000
    x = tf.expand_dims(
        tf.constant(
            np.repeat([[1, 2, 2, 1, 3], [3, 1, 1, 2, 1]], repeats=N, axis=0),
            dtype=tf.float32),
        axis=1)
    s = spn.Sum()
    probs = [[1 / 9, 2 / 9, 2 / 9, 1 / 9, 3 / 9],
             [3 / 8, 1 / 8, 1 / 8, 2 / 8, 1 / 8]]
    N_sampled = N * sample_prob
    N_argmax = N - N_sampled
    sample_op = tf.squeeze(
        s._reduce_sample_log(tf.log(x), sample_prob=sample_prob))

    with self.test_session() as sess:
        sample_out = sess.run(sample_op)

    for samples, prob, max_ind in zip(
            np.split(sample_out, indices_or_sections=2), probs, [4, 0]):
        hist, _ = np.histogram(samples, bins=list(range(6)))
        for h, p, i in zip(hist, prob, range(5)):
            if i == max_ind:
                estimate = N_argmax + N_sampled * p
            else:
                estimate = N_sampled * p
            self.assertLess(h, estimate + N / 6)
            self.assertGreater(h, estimate - N / 6)
def test_dropconnect(self):
    ivs = spn.IVs(num_vals=2, num_vars=4)
    s = spn.Sum(ivs, dropconnect_keep_prob=0.5)
    spn.generate_weights(s)
    init = spn.initialize_weights(s)

    mask = [[0., 1., 0., 1., 1., 1., 0., 1.],
            [1., 0., 0., 0., 0., 0., 1., 0.]]
    s._create_dropout_mask = MagicMock(
        return_value=tf.expand_dims(tf.log(mask), 1))

    val_op = tf.exp(s.get_log_value())

    mask = tf.constant(mask, dtype=tf.float32)
    truth = tf.reduce_mean(mask, axis=-1, keepdims=True)

    with self.test_session() as sess:
        sess.run(init)
        dropconnect_out, truth_out = sess.run(
            [val_op, truth],
            feed_dict={ivs: -np.ones((2, 4), dtype=np.int32)})
        self.assertAllClose(dropconnect_out, truth_out)
import libspn as spn
import tensorflow as tf

indicator_leaves = spn.IndicatorLeaf(
    num_vars=2, num_vals=2, name="indicator_x")

# Connect first two sums to indicators of the first variable
sum_11 = spn.Sum((indicator_leaves, [0, 1]), name="sum_11")
sum_12 = spn.Sum((indicator_leaves, [0, 1]), name="sum_12")

# Connect another two sums to indicators of the second variable
sum_21 = spn.Sum((indicator_leaves, [2, 3]), name="sum_21")
sum_22 = spn.Sum((indicator_leaves, [2, 3]), name="sum_22")

# Connect three product nodes
prod_1 = spn.Product(sum_11, sum_21, name="prod_1")
prod_2 = spn.Product(sum_11, sum_22, name="prod_2")
prod_3 = spn.Product(sum_12, sum_22, name="prod_3")

# Connect a root sum
root = spn.Sum(prod_1, prod_2, prod_3, name="root")

# Connect a latent indicator
indicator_y = root.generate_latent_indicators(
    name="indicator_y")  # Can be added manually

# Generate weights
spn.generate_weights(
    root, initializer=tf.initializers.random_uniform())  # Can be added manually
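# Hedged usage sketch (not part of the original snippet): initialize the
# generated weights and evaluate the root for a small batch, using only the
# helpers shown elsewhere in this collection. The feed values are arbitrary;
# -1 for indicator_y leaves the latent variable unobserved.
import numpy as np

init_weights = spn.initialize_weights(root)
root_value = root.get_value()

with tf.Session() as sess:
    sess.run(init_weights)
    out = sess.run(root_value, feed_dict={
        indicator_leaves: np.array([[0, 1], [1, 0]], dtype=np.int32),
        indicator_y: -np.ones((2, 1), dtype=np.int32)})
    print(out)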
for i in range(stack_size):
    dilation_rate = 2 ** i
    x = spn.ConvProductsDepthwise(
        x, padding='full', kernel_size=2, strides=1,
        dilation_rate=dilation_rate)
    x = spn.LocalSums(x, num_channels=64)

# Create final layer of products
full_scope_prod = spn.ConvProductsDepthwise(
    x, padding='wicker_top', kernel_size=2, strides=1,
    dilation_rate=2 ** stack_size)
class_roots = spn.ParallelSums(full_scope_prod, num_sums=num_classes)
root = spn.Sum(class_roots)

# Add an IndicatorLeaf node to the root as a latent class variable
class_indicators = root.generate_latent_indicators()

# Generate the weights for the SPN rooted at `root`
spn.generate_weights(
    root, log=True, initializer=tf.initializers.random_uniform())

print("SPN depth: {}".format(root.get_depth()))
print("Number of products layers: {}".format(
    root.get_num_nodes(node_type=spn.ConvProducts)))
print("Number of sums layers: {}".format(
    root.get_num_nodes(node_type=spn.LocalSums)))
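# Hedged follow-up sketch (not in the original snippet): build the usual
# weight-initialization op and, mirroring the MPEState usage in
# setup_learning() further below, an MPE-state op that recovers the most
# likely class for a batch. The spatial leaf feeding `x` above is assumed to
# be defined earlier.
init_weights = spn.initialize_weights(root)
mpe_state = spn.MPEState(value_inference_type=spn.InferenceType.MPE,
                         matmul_or_conv=True)
class_mpe, = mpe_state.get_state(root, class_indicators)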
def test_generate_spn(self, num_decomps, num_subsets, num_mixtures,
                      num_input_mixtures, input_dims, input_dist, balanced,
                      node_type, log_weights):
    """A generic test for DenseSPNGenerator."""
    if input_dist == spn.DenseSPNGenerator.InputDist.RAW \
            and num_input_mixtures != 1:
        # Redundant test case, so just return
        return

    # Input parameters
    num_inputs = input_dims[0]
    num_vars = input_dims[1]
    num_vals = 2

    printc("\n- num_inputs: %s" % num_inputs)
    printc("- num_vars: %s" % num_vars)
    printc("- num_vals: %s" % num_vals)
    printc("- num_decomps: %s" % num_decomps)
    printc("- num_subsets: %s" % num_subsets)
    printc("- num_mixtures: %s" % num_mixtures)
    printc("- input_dist: %s" % ("MIXTURE" if input_dist ==
           spn.DenseSPNGenerator.InputDist.MIXTURE else "RAW"))
    printc("- balanced: %s" % balanced)
    printc("- num_input_mixtures: %s" % num_input_mixtures)
    printc("- node_type: %s" % ("SINGLE" if node_type ==
           spn.DenseSPNGenerator.NodeType.SINGLE else "BLOCK" if node_type ==
           spn.DenseSPNGenerator.NodeType.BLOCK else "LAYER"))
    printc("- log_weights: %s" % log_weights)

    # Inputs
    inputs = [
        spn.IndicatorLeaf(num_vars=num_vars, num_vals=num_vals,
                          name=("IndicatorLeaf_%d" % (i + 1)))
        for i in range(num_inputs)
    ]

    gen = spn.DenseSPNGenerator(num_decomps=num_decomps,
                                num_subsets=num_subsets,
                                num_mixtures=num_mixtures,
                                input_dist=input_dist,
                                balanced=balanced,
                                num_input_mixtures=num_input_mixtures,
                                node_type=node_type)

    # Generate sub-SPNs
    sub_spns = [
        gen.generate(*inputs, root_name=("sub_root_%d" % (i + 1)))
        for i in range(3)
    ]

    # Generate random weights for the first sub-SPN
    with tf.name_scope("Weights"):
        spn.generate_weights(sub_spns[0],
                             tf.initializers.random_uniform(0.0, 1.0),
                             log=log_weights)

    # Initialize weights of the first sub-SPN
    sub_spn_init = spn.initialize_weights(sub_spns[0])

    # Testing validity of the first sub-SPN
    self.assertTrue(sub_spns[0].is_valid())

    # Generate value ops of the first sub-SPN
    sub_spn_v = sub_spns[0].get_value()
    sub_spn_v_log = sub_spns[0].get_log_value()

    # Generate path ops of the first sub-SPN
    sub_spn_mpe_path_gen = spn.MPEPath(log=False)
    sub_spn_mpe_path_gen_log = spn.MPEPath(log=True)
    sub_spn_mpe_path_gen.get_mpe_path(sub_spns[0])
    sub_spn_mpe_path_gen_log.get_mpe_path(sub_spns[0])
    sub_spn_path = [sub_spn_mpe_path_gen.counts[inp] for inp in inputs]
    sub_spn_path_log = [
        sub_spn_mpe_path_gen_log.counts[inp] for inp in inputs
    ]

    # Collect all weight nodes of the first sub-SPN
    sub_spn_weight_nodes = []

    def fun(node):
        if node.is_param:
            sub_spn_weight_nodes.append(node)

    spn.traverse_graph(sub_spns[0], fun=fun)

    # Generate an upper SPN over the sub-SPNs
    products_lower = []
    for sub_spn in sub_spns:
        products_lower.append([v.node for v in sub_spn.values])

    num_top_mixtures = [2, 1, 3]
    sums_lower = []
    for prods, num_top_mix in zip(products_lower, num_top_mixtures):
        if node_type == spn.DenseSPNGenerator.NodeType.SINGLE:
            sums_lower.append(
                [spn.Sum(*prods) for _ in range(num_top_mix)])
        elif node_type == spn.DenseSPNGenerator.NodeType.BLOCK:
            sums_lower.append(
                [spn.ParallelSums(*prods, num_sums=num_top_mix)])
        else:
            sums_lower.append([
                spn.SumsLayer(*prods * num_top_mix,
                              num_or_size_sums=num_top_mix)
            ])

    # Generate the upper SPN
    root = gen.generate(*list(itertools.chain(*sums_lower)),
                        root_name="root")

    # Generate random weights for the SPN
    with tf.name_scope("Weights"):
        spn.generate_weights(root,
                             tf.initializers.random_uniform(0.0, 1.0),
                             log=log_weights)

    # Initialize weights of the SPN
    spn_init = spn.initialize_weights(root)

    # Testing validity of the SPN
    self.assertTrue(root.is_valid())

    # Generate value ops of the SPN
    spn_v = root.get_value()
    spn_v_log = root.get_log_value()

    # Generate path ops of the SPN
    spn_mpe_path_gen = spn.MPEPath(log=False)
    spn_mpe_path_gen_log = spn.MPEPath(log=True)
    spn_mpe_path_gen.get_mpe_path(root)
    spn_mpe_path_gen_log.get_mpe_path(root)
    spn_path = [spn_mpe_path_gen.counts[inp] for inp in inputs]
    spn_path_log = [spn_mpe_path_gen_log.counts[inp] for inp in inputs]

    # Collect all weight nodes in the SPN
    spn_weight_nodes = []

    def fun(node):
        if node.is_param:
            spn_weight_nodes.append(node)

    spn.traverse_graph(root, fun=fun)

    # Create a session
    with self.test_session() as sess:
        # Initializing weights
        sess.run(sub_spn_init)
        sess.run(spn_init)

        # Generate input feed
        feed = np.array(
            list(itertools.product(range(num_vals),
                                   repeat=(num_inputs * num_vars))))
        batch_size = feed.shape[0]
        feed_dict = {}
        for inp, f in zip(inputs, np.split(feed, num_inputs, axis=1)):
            feed_dict[inp] = f

        # Compute all values and paths of sub-SPN
        sub_spn_out = sess.run(sub_spn_v, feed_dict=feed_dict)
        sub_spn_out_log = sess.run(tf.exp(sub_spn_v_log),
                                   feed_dict=feed_dict)
        sub_spn_out_path = sess.run(sub_spn_path, feed_dict=feed_dict)
        sub_spn_out_path_log = sess.run(sub_spn_path_log,
                                        feed_dict=feed_dict)

        # Compute all values and paths of the complete SPN
        spn_out = sess.run(spn_v, feed_dict=feed_dict)
        spn_out_log = sess.run(tf.exp(spn_v_log), feed_dict=feed_dict)
        spn_out_path = sess.run(spn_path, feed_dict=feed_dict)
        spn_out_path_log = sess.run(spn_path_log, feed_dict=feed_dict)

        # Test if partition function of the sub-SPN and of the
        # complete SPN is 1.0
        self.assertAlmostEqual(sub_spn_out.sum(), 1.0, places=6)
        self.assertAlmostEqual(sub_spn_out_log.sum(), 1.0, places=6)
        self.assertAlmostEqual(spn_out.sum(), 1.0, places=6)
        self.assertAlmostEqual(spn_out_log.sum(), 1.0, places=6)

        # Test if the sum of counts for each value of each variable
        # (6 variables, with 2 values each) = batch-size / num-vals
        self.assertEqual(
            np.sum(np.hstack(sub_spn_out_path), axis=0).tolist(),
            [batch_size // num_vals] * num_inputs * num_vars * num_vals)
        self.assertEqual(
            np.sum(np.hstack(sub_spn_out_path_log), axis=0).tolist(),
            [batch_size // num_vals] * num_inputs * num_vars * num_vals)
        self.assertEqual(
            np.sum(np.hstack(spn_out_path), axis=0).tolist(),
            [batch_size // num_vals] * num_inputs * num_vars * num_vals)
        self.assertEqual(
            np.sum(np.hstack(spn_out_path_log), axis=0).tolist(),
            [batch_size // num_vals] * num_inputs * num_vars * num_vals)
# Generates densely connected random SPNs
dense_generator = spn.DenseSPNGenerator(
    node_type=spn.DenseSPNGenerator.NodeType.BLOCK,
    num_subsets=num_subsets,
    num_mixtures=num_mixtures,
    num_decomps=num_decomps,
    balanced=balanced,
    input_dist=input_dist)

# Generate a dense SPN for each class
class_roots = [
    dense_generator.generate(leaf_indicators) for _ in range(num_classes)
]

# Connect sub-SPNs to a root
root = spn.Sum(*class_roots, name="RootSum")
root = spn.convert_to_layer_nodes(root)

# Add an IVs node to the root as a latent class variable
class_indicators = root.generate_latent_indicators()

# Generate the weights for the SPN rooted at `root`
spn.generate_weights(root)

print("SPN depth: {}".format(root.get_depth()))
print("Number of products layers: {}".format(
    root.get_num_nodes(node_type=spn.ProductsLayer)))
print("Number of sums layers: {}".format(
    root.get_num_nodes(node_type=spn.SumsLayer)))

# Op for initializing all weights
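# A minimal sketch of the op the trailing comment above refers to, using the
# same helper as the other snippets in this collection:
init_weights = spn.initialize_weights(root)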
def setup_learning(args, in_var, root):
    no_op = tf.constant(0)
    inference_type = spn.InferenceType.MARGINAL \
        if args.value_inf_type == 'marginal' else spn.InferenceType.MPE
    mpe_state = spn.MPEState(value_inference_type=inference_type,
                             matmul_or_conv=True)

    if args.supervised:
        # Root is provided with labels, p(x,y)
        labels_node = root.generate_latent_indicators(name="LabelIndicators")

        # Marginalized root, i.e. without filling in labels,
        # so p(x) = \sum_y p(x,y)
        root_marginalized = spn.Sum(*root.values, name="RootMarginalized",
                                    weights=root.weights)
        # A dummy node to get the MPE state
        labels_no_evidence_node = root_marginalized.generate_latent_indicators(
            name="LabelsNoEvidenceIndicators",
            feed=-tf.ones([tf.shape(in_var.feed)[0], 1], dtype=tf.int32))

        # Get prediction from the dummy node
        with tf.name_scope("Prediction"):
            logger.info("Setting up MPE state")
            if args.completion_by_marginal and isinstance(
                    in_var, ContinuousLeafBase):
                in_var_mpe = in_var.impute_by_posterior_marginal(
                    labels_no_evidence_node)
                class_mpe, = mpe_state.get_state(root_marginalized,
                                                 labels_no_evidence_node)
            else:
                class_mpe, in_var_mpe = mpe_state.get_state(
                    root_marginalized, labels_no_evidence_node, in_var)
            correct = tf.squeeze(
                tf.equal(class_mpe, tf.to_int64(labels_node.feed)))
    else:
        with tf.name_scope("Prediction"):
            class_mpe = correct = no_op
            labels_node = root_marginalized = None
            if args.completion_by_marginal and isinstance(
                    in_var, ContinuousLeafBase):
                in_var_mpe = in_var.impute_by_posterior_marginal(root)
            else:
                in_var_mpe, = mpe_state.get_state(root, in_var)

    # Get the log likelihood
    with tf.name_scope("LogLikelihoods"):
        logger.info("Setting up log-likelihood")
        val_gen = spn.LogValue(inference_type=inference_type)
        labels_llh = val_gen.get_value(root)
        no_labels_llh = val_gen.get_value(root_marginalized) \
            if args.supervised else labels_llh

    if args.learning_algo == "em":
        em_learning = spn.HardEMLearning(
            root,
            value_inference_type=inference_type,
            initial_accum_value=args.initial_accum_value,
            sample_winner=args.sample_path,
            sample_prob=args.sample_prob,
            use_unweighted=args.use_unweighted)
        accumulate = em_learning.accumulate_updates()
        with tf.control_dependencies([accumulate]):
            update_op = em_learning.update_spn()

        return correct, labels_node, labels_llh, no_labels_llh, update_op, \
            class_mpe, no_op, no_op, in_var_mpe

    logger.info("Setting up GD learning")
    global_step = tf.Variable(0, trainable=False)
    learning_rate = tf.train.exponential_decay(args.learning_rate,
                                               global_step,
                                               args.lr_decay_steps,
                                               args.lr_decay_rate,
                                               staircase=True)

    learning_method = spn.LearningMethodType.DISCRIMINATIVE \
        if args.learning_type == 'discriminative' \
        else spn.LearningMethodType.GENERATIVE
    learning = spn.GDLearning(
        root,
        learning_task_type=spn.LearningTaskType.SUPERVISED
        if args.supervised else spn.LearningTaskType.UNSUPERVISED,
        learning_method=learning_method,
        learning_rate=learning_rate,
        marginalizing_root=root_marginalized,
        global_step=global_step)
    optimizer = {
        'adam': tf.train.AdamOptimizer,
        'rmsprop': tf.train.RMSPropOptimizer,
        'amsgrad': AMSGrad,
    }[args.learning_algo]()

    minimize_op, _ = learning.learn(optimizer=optimizer)

    logger.info("Setting up test loss")
    with tf.name_scope("DeterministicLoss"):
        main_loss = learning.loss()
        regularization_loss = learning.regularization_loss()
        loss_per_sample = learning.loss(
            reduce_fn=lambda x: tf.reshape(x, (-1,)))

    return correct, labels_node, main_loss, no_labels_llh, minimize_op, \
        class_mpe, regularization_loss, loss_per_sample, in_var_mpe
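# Hedged usage sketch (not part of the original file): one way the ops returned
# by setup_learning() might be driven from a session. `args`, `in_var`, `root`,
# `batch_x` and `batch_y` are hypothetical placeholders here; the third return
# is labels_llh for EM and main_loss for GD, and the fifth is the corresponding
# update/minimize op.
correct, labels_node, loss_or_llh, no_labels_llh, train_op, class_mpe, \
    reg_loss, loss_per_sample, in_var_mpe = setup_learning(args, in_var, root)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    feed_dict = {in_var: batch_x}
    if labels_node is not None:
        feed_dict[labels_node] = batch_y
    _, loss_out = sess.run([train_op, loss_or_llh], feed_dict=feed_dict)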
spn.generate_weights(
    root, initializer=tf.initializers.random_uniform(0, 1))  # Can be added manually

print(root.get_num_nodes())
print(root.get_scope())
print(root.is_valid())

SUM_CNT = 8
sum_ls = []
# iv_x = spn.IndicatorLeaf([[0, -1], [-1, -1]], num_vars=2, num_vals=2, name="iv_x")
for i in range(SUM_CNT):
    iv_x = spn.IndicatorLeaf([[-1], [0]], num_vars=1, num_vals=2,
                             name="iv_x" + str(i))
    # for i in range(2):
    sum_x = spn.Sum((iv_x, [0, 1]), name="sum_" + str(i))
    sum_x.generate_weights(
        tf.initializers.constant([random.random(), random.random()]))
    sum_ls.append(sum_x)

# for i in range(2, 4):
#     sum_x = spn.Sum((iv_x, [0, 1]), name="sum_" + str(i))
#     sum_x.generate_weights(tf.initializers.constant([random.random(), random.random()]))
#     sum_ls.append(sum_x)

LAST_CNT = SUM_CNT
last_node_ls = sum_ls
node_type = 'Sum'
LAYER_CNT = int(math.log(SUM_CNT / 2, 2))
print('Layer CNT', LAYER_CNT)

for l in range(LAYER_CNT):
name="indicator_x") # Connect first two sums to indicators of first variable sums_1 = spn.ParallelSums((indicator_leaves, [0, 1]), num_sums=2, name="sums_1") # Connect another two sums to indicators of second variable sums_2 = spn.ParallelSums((indicator_leaves, [2, 3]), num_sums=2, name="sums_2") # Connect 2 * 2 == 4 product nodes prods_1 = spn.PermuteProducts(sums_1, sums_2, name="prod_1") # Connect a root sum root = spn.Sum(prods_1, name="root") # Connect a latent indicator indicator_y = root.generate_latent_indicators( name="indicator_y") # Can be added manually # Generate weights spn.generate_weights( root, initializer=tf.initializers.random_uniform()) # Can be added manually ## Inspect # Inspect print(root.get_num_nodes()) print(root.get_scope())
def dup_fun_up(inpt, *args,
               conc=None, tmpl_num_vars=[0], tmpl_num_vals=[0],
               graph_num_vars=[0], labels=[[]], tspn=None):
    """
    Purely for template SPN copying only. Supports templates with multiple
    types of IVs. Requires that the template SPN contains only one concat node
    where all inputs go through.

    labels: (2D list) variable's numerical label, used to locate the
        variable's position in the big IVs. If there are multiple types of
        IVs, then this should be a 2D list, where each inner list is the label
        (starting from 0) for one type of IVs, and each outer list represents
        one type of IVs.
    """
    # Know what range of indices each variable takes
    node, indices = inpt
    if node.is_op:
        if isinstance(node, spn.Sum):
            # [2:] skips the weights node and the explicit IVs node for this sum.
            return spn.Sum(*args[2:], weights=args[0])
        elif isinstance(node, spn.ParSums):
            return spn.ParSums(*args[2:], weights=args[0],
                               num_sums=tspn._num_mixtures)
        elif isinstance(node, spn.Product):
            return spn.Product(*args)
        elif isinstance(node, spn.PermProducts):
            return spn.PermProducts(*args)
        elif isinstance(node, spn.Concat):
            # The goal is to map from an index on the template SPN's concat
            # node to the index on the instance SPN's concat node.
            # First, be able to tell which type of IV the index belongs to.
            # ranges_tmpl stores the start (inclusive) index of the range of
            # indices taken by each type of IV on the template SPN;
            # ranges_instance stores the same for the instance SPN.
            ranges_tmpl = [0]
            ranges_instance = [0]
            for i in range(len(tmpl_num_vars)):
                ranges_tmpl.append(ranges_tmpl[-1]
                                   + tmpl_num_vars[i] * tmpl_num_vals[i])
                ranges_instance.append(ranges_instance[-1]
                                       + graph_num_vars[i] * tmpl_num_vals[i])

            big_indices = []
            for indx in indices:
                iv_type = -1
                for i, start in enumerate(ranges_tmpl):
                    if indx < start + tmpl_num_vars[i] * tmpl_num_vals[i]:
                        iv_type = i
                        break
                if iv_type == -1:
                    raise ValueError(
                        "Oops. Something wrong. Index out of range.")

                # Then, figure out the variable index and offset
                # (w.r.t. the template Concat node)
                varidx = (indx - ranges_tmpl[iv_type]) // tmpl_num_vals[iv_type]
                offset = (indx - ranges_tmpl[iv_type]) \
                    - varidx * tmpl_num_vals[iv_type]
                # THIS IS the actual position of the variable's inputs
                # in the big Concat.
                varlabel = labels[iv_type][varidx]
                big_indices.append(ranges_instance[iv_type]
                                   + varlabel * tmpl_num_vals[iv_type]
                                   + offset)
            return spn.Input(conc, big_indices)
    elif isinstance(node, spn.Weights):
        return node
    else:
        raise ValueError(
            "Unexpected node %s. We don't intend to deal with IVs here. "
            "Please remove them from the concat." % node)
def _init_struct(self, sess, divisions=-1, num_partitions=1, partitions=None,
                 extra_partition_multiplyer=1):
    """
    Initialize the structure for training. (private method)

    sess (tf.Session): a session that contains all weights.

    **kwargs:
        num_partitions (int): number of partitions (children for root node).
        divisions (int): number of views per place (if template is
            EdgeTemplate).
        extra_partition_multiplyer (int): used to multiply num_partitions so
            that more partitions are tried and ones with higher coverage are
            picked.
    """
    for tspn, template in self._spns:
        # Remove inputs; this is necessary for the duplication to work - we
        # don't want the indicator variables of the template SPNs because the
        # instance SPN has its own indicator variable inputs.
        tspn._conc_inputs.set_inputs()

    # Create vars and maps
    self._catg_inputs = spn.IVs(num_vars=len(self._graph.nodes),
                                num_vals=self._num_vals)
    self._conc_inputs = spn.Concat(self._catg_inputs)
    self._template_nodes_map = {}  # map from template id to list of node ids
    self._node_label_map = {}  # key: node id. Value: a number (0~num_nodes-1)
    self._label_node_map = {}  # key: a number (0~num_nodes-1). Value: node id
    _i = 0
    for nid in self._graph.nodes:
        self._node_label_map[nid] = _i
        self._label_node_map[_i] = nid
        _i += 1

    if partitions is None:
        """Try partitioning the graph `extra_partition_multiplyer` times more
        than what is asked for. Then pick the top `num_partitions` with the
        highest coverage of the main template."""
        print("Partitioning the graph... (Selecting %d from %d attempts)"
              % (num_partitions, extra_partition_multiplyer * num_partitions))
        partitioned_results = {}
        main_template = self._spns[0][1]
        for i in range(extra_partition_multiplyer * num_partitions):
            """Note: here, we only partition with the main template. The
            results (i.e. supergraph, unused graph) are stored and will be
            used later.
""" unused_graph, supergraph = self._graph.partition( main_template, get_unused=True, super_node_class=self._super_node_class, super_edge_class=self._super_edge_class) if self._template_mode == NodeTemplate.code(): ## NodeTemplate coverage = len( supergraph.nodes) * main_template.size() / len( self._graph.nodes) partitioned_results[(i, coverage)] = (supergraph, unused_graph) used_coverages = set({}) for i, coverage in sorted(partitioned_results, reverse=True, key=lambda x: x[1]): used_coverages.add((i, coverage)) sys.stdout.write("%.3f " % coverage) if len(used_coverages) >= num_partitions: break sys.stdout.write("\n") """Keep partitioning the used partitions, and obtain a list of partitions in the same format as the `partitions` parameter""" partitions = [] for key in used_coverages: supergraph, unused_graph = partitioned_results[key] partition = {main_template: supergraph} # Keep partitioning the unused_graph using smaller templates for _, template in self._spns[1:]: # skip main template unused_graph_2nd, supergraph_2nd = unused_graph.partition( template, get_unused=True, super_node_class=self._super_node_class, super_edge_class=self._super_edge_class) partition[template] = supergraph_2nd unused_graph = unused_graph_2nd partitions.append(partition) """Building instance spn""" print("Building instance spn...") pspns = [] tspns = {} for template_spn, template in self._spns: tspns[template.__name__] = template_spn """Making an SPN""" """Now, partition the graph, copy structure, and connect self._catg_inputs appropriately to the network.""" # Main template partition _k = 0 self._partitions = partitions for _k, partition in enumerate(self._partitions): print("Partition %d" % (_k + 1)) nodes_covered = set({}) template_spn_roots = [] for template_spn, template in self._spns: supergraph = partition[template] print("Will duplicate %s %d times." % (template.__name__, len(supergraph.nodes))) template_spn_roots.extend( NodeTemplateInstanceSpn._duplicate_template_spns( self, tspns, template, supergraph, nodes_covered)) ## TEST CODE: COMMENT OUT WHEN ACTUALLY RUNNING # original_tspn_root = tspns[template.__name__].root # duplicated_tspn_root = template_spn_roots[-1] # original_tspn_weights = sess.run(original_tspn_root.weights.node.get_value()) # duplicated_tspn_weights = sess.run(duplicated_tspn_root.weights.node.get_value()) # print(original_tspn_weights) # print(duplicated_tspn_weights) # print(original_tspn_weights == duplicated_tspn_weights) # import pdb; pdb.set_trace() assert nodes_covered == self._graph.nodes.keys() p = spn.Product(*template_spn_roots) assert p.is_valid() pspns.append(p) # add spn for one partition ## End for loop ## # Sum up all self._root = spn.Sum(*pspns) assert self._root.is_valid() self._root.generate_weights(trainable=True) # initialize ONLY the weights node for the root sess.run(self._root.weights.node.initialize())
def test_compare_manual_conv(self, log_weights, inference_type):
    spn.conf.argmax_zero = True
    grid_dims = [2, 2]
    nrows, ncols = grid_dims
    num_vals = 4
    batch_size = 256
    num_vars = grid_dims[0] * grid_dims[1]
    indicator_leaf = spn.IndicatorLeaf(num_vars=num_vars, num_vals=num_vals)
    num_sums = 32

    weights = spn.Weights(num_weights=num_vals,
                          num_sums=num_sums * num_vars,
                          initializer=tf.initializers.random_uniform(),
                          log=log_weights)
    weights_per_cell = tf.split(weights.variable,
                                num_or_size_splits=num_vars)

    parsums = []
    for row in range(nrows):
        for col in range(ncols):
            indices = list(
                range(row * (ncols * num_vals) + col * num_vals,
                      row * (ncols * num_vals) + (col + 1) * num_vals))
            parsums.append(
                spn.ParallelSums((indicator_leaf, indices),
                                 num_sums=num_sums))
    parsum_concat = spn.Concat(*parsums, name="ParSumConcat")

    convsum = spn.LocalSums(indicator_leaf, num_channels=num_sums,
                            weights=weights, spatial_dim_sizes=grid_dims)

    prod00_conv = spn.PermuteProducts(
        (convsum, list(range(num_sums))),
        (convsum, list(range(num_sums, num_sums * 2))),
        name="Prod00")
    prod01_conv = spn.PermuteProducts(
        (convsum, list(range(num_sums * 2, num_sums * 3))),
        (convsum, list(range(num_sums * 3, num_sums * 4))),
        name="Prod01")
    sum00_conv = spn.ParallelSums(prod00_conv, num_sums=2)
    sum01_conv = spn.ParallelSums(prod01_conv, num_sums=2)

    prod10_conv = spn.PermuteProducts(sum00_conv, sum01_conv, name="Prod10")
    conv_root = spn.Sum(prod10_conv)

    prod00_pars = spn.PermuteProducts(
        (parsum_concat, list(range(num_sums))),
        (parsum_concat, list(range(num_sums, num_sums * 2))))
    prod01_pars = spn.PermuteProducts(
        (parsum_concat, list(range(num_sums * 2, num_sums * 3))),
        (parsum_concat, list(range(num_sums * 3, num_sums * 4))))
    sum00_pars = spn.ParallelSums(prod00_pars, num_sums=2)
    sum01_pars = spn.ParallelSums(prod01_pars, num_sums=2)

    prod10_pars = spn.PermuteProducts(sum00_pars, sum01_pars)
    parsum_root = spn.Sum(prod10_pars)

    node_pairs = [(sum00_conv, sum00_pars), (sum01_conv, sum01_pars),
                  (conv_root, parsum_root)]

    self.assertTrue(conv_root.is_valid())
    self.assertTrue(parsum_root.is_valid())
    self.assertAllEqual(parsum_concat.get_scope(), convsum.get_scope())

    spn.generate_weights(conv_root, log=log_weights,
                         initializer=tf.initializers.random_uniform())
    spn.generate_weights(parsum_root, log=log_weights,
                         initializer=tf.initializers.random_uniform())

    convsum.set_weights(weights)
    copy_weight_ops = []
    parsum_weight_nodes = []
    for p, w in zip(parsums, weights_per_cell):
        copy_weight_ops.append(tf.assign(p.weights.node.variable, w))
        parsum_weight_nodes.append(p.weights.node)
    for wc, wp in node_pairs:
        copy_weight_ops.append(
            tf.assign(wp.weights.node.variable, wc.weights.node.variable))

    copy_weights_op = tf.group(*copy_weight_ops)

    init_conv = spn.initialize_weights(conv_root)
    init_parsum = spn.initialize_weights(parsum_root)

    path_conv = spn.MPEPath(value_inference_type=inference_type)
    path_conv.get_mpe_path(conv_root)

    path_parsum = spn.MPEPath(value_inference_type=inference_type)
    path_parsum.get_mpe_path(parsum_root)

    indicator_counts_parsum = path_parsum.counts[indicator_leaf]
    indicator_counts_convsum = path_conv.counts[indicator_leaf]

    weight_counts_parsum = tf.concat(
        [path_parsum.counts[w] for w in parsum_weight_nodes], axis=1)
    weight_counts_conv = path_conv.counts[weights]

    weight_parsum_concat = tf.concat(
        [w.variable for w in parsum_weight_nodes], axis=0)

    root_val_parsum = parsum_root.get_log_value()  # path_parsum.value.values[parsum_root]
    root_val_conv = conv_root.get_log_value()  # path_conv.value.values[conv_root]

    parsum_counts = path_parsum.counts[parsum_concat]
    conv_counts = path_conv.counts[convsum]

    indicator_feed = np.random.randint(-1, 2, size=batch_size * num_vars)\
        .reshape((batch_size, num_vars))

    with tf.Session() as sess:
        sess.run([init_conv, init_parsum])
        sess.run(copy_weights_op)

        indicator_counts_conv_out, indicator_counts_parsum_out = sess.run(
            [indicator_counts_convsum, indicator_counts_parsum],
            feed_dict={indicator_leaf: indicator_feed})

        root_conv_value_out, root_parsum_value_out = sess.run(
            [root_val_conv, root_val_parsum],
            feed_dict={indicator_leaf: indicator_feed})

        weight_counts_conv_out, weight_counts_parsum_out = sess.run(
            [weight_counts_conv, weight_counts_parsum],
            feed_dict={indicator_leaf: indicator_feed})

        weight_value_conv_out, weight_value_parsum_out = sess.run(
            [convsum.weights.node.variable, weight_parsum_concat])

        parsum_counts_out, conv_counts_out = sess.run(
            [parsum_counts, conv_counts],
            feed_dict={indicator_leaf: indicator_feed})

        parsum_concat_val, convsum_val = sess.run(
            [path_parsum.value.values[parsum_concat],
             path_conv.value.values[convsum]],
            feed_dict={indicator_leaf: indicator_feed})

    self.assertAllClose(convsum_val, parsum_concat_val)
    self.assertAllClose(weight_value_conv_out, weight_value_parsum_out)
    self.assertAllClose(root_conv_value_out, root_parsum_value_out)
    self.assertAllEqual(indicator_counts_conv_out,
                        indicator_counts_parsum_out)
    self.assertAllEqual(parsum_counts_out, conv_counts_out)
    self.assertAllEqual(weight_counts_conv_out, weight_counts_parsum_out)