def par_sums(inputs, indices, latent_indicators, num_sums, inf_type=None,
             log=True, output=None):
    """Build a ParallelSums graph under a root Sum and return its MPE-path op.

    A single ``spn.ParallelSums`` node modelling ``num_sums`` sums is wired to
    ``inputs`` (restricted to ``indices`` when given) and to the last element
    of ``latent_indicators``. That node is the only child of a root
    ``spn.Sum``.

    Args:
        inputs: Input node feeding the ParallelSums node.
        indices: Optional column indices into ``inputs``; ``None`` uses the
            whole input.
        latent_indicators: Indexable sequence; only its last element is used.
        num_sums: Number of sums modelled inside the ParallelSums node.
        inf_type: Value inference type forwarded to ``spn.MPEPath``.
        log: Whether the MPE path is computed in log space.
        output: Unused; kept for a uniform builder signature.

    Returns:
        ``(init_op, path_ops)``: ``init_op`` initializes every weight under the
        root; ``path_ops`` is a one-element list with the MPE-path counts of
        the ParallelSums weights.
    """
    input_spec = inputs if indices is None else (inputs, indices)
    par_node = spn.ParallelSums(input_spec, num_sums=num_sums,
                                latent_indicators=latent_indicators[-1])
    par_weights = par_node.generate_weights()

    # Single root Sum over the ParallelSums node, with its own weights.
    root = spn.Sum(par_node)
    root.generate_weights()

    path_gen = spn.MPEPath(value_inference_type=inf_type, log=bool(log))
    path_gen.get_mpe_path(root)
    path_ops = [path_gen.counts[par_weights]]
    return spn.initialize_weights(root), path_ops
def test_mpe_path(self):
    """Check latent-indicator MPE-path counts of Poon11NaiveMixtureModel.

    The same expected counts are asserted for a linear-space and a log-space
    MPE path generator built over the model root.
    """
    model = spn.Poon11NaiveMixtureModel()
    model.build()

    init = spn.initialize_weights(model.root)
    # Linear-space generator first, log-space second.
    generators = [
        spn.MPEPath(value_inference_type=spn.InferenceType.MPE, log=False),
        spn.MPEPath(value_inference_type=spn.InferenceType.MPE, log=True),
    ]
    for gen in generators:
        gen.get_mpe_path(model.root)

    feed = {model.latent_indicators: model.feed}
    with self.test_session() as sess:
        init.run()
        results = [
            sess.run(gen.counts[model.latent_indicators], feed_dict=feed)
            for gen in generators
        ]

    expected_counts = np.array(
        [[0., 1., 1., 0.],
         [0., 1., 1., 0.],
         [0., 1., 0., 1.],
         [1., 0., 1., 0.],
         [1., 0., 1., 0.],
         [1., 0., 0., 1.],
         [0., 1., 1., 0.],
         [0., 1., 1., 0.],
         [0., 1., 0., 1.]],
        dtype=spn.conf.dtype.as_numpy_dtype)
    for counts in results:
        np.testing.assert_array_equal(counts, expected_counts)
def poons_multi(inputs, num_vals, num_mixtures, num_subsets, inf_type,
                log=False, output=None):
    """Build a Poon-like network from multi-op nodes and return MPE-path ops.

    Subset ``k`` is a ``spn.ParSums`` over columns
    ``[k * num_vals, (k + 1) * num_vals)`` of ``inputs``. All subsets are
    joined by one ``spn.PermProducts`` node under a root ``spn.Sum``.

    Args:
        inputs: Input node sliced into ``num_subsets`` column ranges.
        num_vals: Width of each column range.
        num_mixtures: Number of sums per ParSums node.
        num_subsets: Number of column ranges / ParSums nodes.
        inf_type: Value inference type forwarded to ``spn.MPEPath``.
        log: Whether the MPE path is computed in log space.
        output: Unused; kept for a uniform builder signature.

    Returns:
        ``(root, init_op, path_ops)`` with one counts tensor per input node.
    """
    parsums = []
    for subset in range(num_subsets):
        lo, hi = subset * num_vals, (subset + 1) * num_vals
        parsums.append(
            spn.ParSums((inputs, list(range(lo, hi))), num_sums=num_mixtures))

    root = spn.Sum(spn.PermProducts(*parsums), name="root")
    # Generate weights for every weighted node in the network
    spn.generate_weights(root)

    path_gen = spn.MPEPath(value_inference_type=inf_type, log=bool(log))
    path_gen.get_mpe_path(root)
    input_nodes = inputs if isinstance(inputs, list) else [inputs]
    path_ops = [path_gen.counts[node] for node in input_nodes]
    return root, spn.initialize_weights(root), path_ops
def sum(inputs, indices, latent_indicators, num_sums, inf_type=None,
        log=False, output=None):
    """Build ``num_sums`` Sum nodes under a root Sum and return MPE-path ops.

    Note: this name shadows the builtin ``sum`` in the scope where it is
    defined. It is kept unchanged for existing callers.

    Args:
        inputs: Input node shared by every Sum node.
        indices: Optional column indices into ``inputs``; ``None`` uses the
            whole input.
        latent_indicators: Indexable sequence; Sum node ``i`` receives
            ``latent_indicators[i]``.
        num_sums: Number of Sum nodes to build.
        inf_type: Value inference type forwarded to ``spn.MPEPath``.
        log: Whether the MPE path is computed in log space.
        output: Unused; kept for a uniform builder signature.

    Returns:
        ``(init_op, path_ops)``: ``init_op`` initializes every weight under the
        root; ``path_ops`` holds the MPE-path counts of each Sum node's
        weights, in creation order.
    """
    input_spec = inputs if indices is None else (inputs, indices)

    # Append in place rather than rebuilding both lists with ``+`` on every
    # iteration. Each node's weights are generated right after the node, so
    # the node/weight creation order is unchanged.
    sum_nodes = []
    weight_nodes = []
    for i in range(num_sums):
        node = spn.Sum(input_spec, latent_indicators=latent_indicators[i])
        sum_nodes.append(node)
        weight_nodes.append(node.generate_weights())

    # Connect all sum nodes to a single root Sum node and generate its weights
    root = spn.Sum(*sum_nodes)
    root.generate_weights()

    mpe_path_gen = spn.MPEPath(value_inference_type=inf_type, log=bool(log))
    mpe_path_gen.get_mpe_path(root)
    path_ops = [mpe_path_gen.counts[w] for w in weight_nodes]
    return spn.initialize_weights(root), path_ops
def sums(inputs, indices, ivs, num_sums, inf_type=None, log=True, output=None):
    """Build a Sums graph under a root Sum and return its MPE-path op.

    A single ``spn.Sums`` node modelling ``num_sums`` sums receives
    ``num_sums`` copies of the same input (restricted to ``indices`` when
    given) and the last element of ``ivs``.

    Args:
        inputs: Input node replicated once per modelled sum.
        indices: Optional column indices into ``inputs``; ``None`` uses the
            whole input.
        ivs: Indexable sequence; only its last element is used.
        num_sums: Number of sums modelled inside the Sums node.
        inf_type: Value inference type forwarded to ``spn.MPEPath``.
        log: Whether the MPE path is computed in log space.
        output: Unused; kept for a uniform builder signature.

    Returns:
        ``(init_op, path_ops)``: ``init_op`` initializes every weight under the
        root; ``path_ops`` is a one-element list with the MPE-path counts of
        the Sums weights.
    """
    input_spec = inputs if indices is None else (inputs, indices)
    replicated_inputs = [input_spec] * num_sums
    sums_node = spn.Sums(*replicated_inputs, num_sums=num_sums, ivs=ivs[-1])
    sums_weights = sums_node.generate_weights()

    root = spn.Sum(sums_node)
    root.generate_weights()

    path_gen = spn.MPEPath(value_inference_type=inf_type, log=bool(log))
    path_gen.get_mpe_path(root)
    path_op = [path_gen.counts[sums_weights]]
    return spn.initialize_weights(root), path_op
def products(inputs, num_inputs, num_input_cols, num_prods, inf_type,
             indices=None, log=False, output=None):
    """Build one Products node per input group under a root Sum.

    For each group ``inps`` in ``inputs``, every combination that picks one
    column index (``0 .. n_inp_cols - 1``) per input node is enumerated. The
    picks are flattened into ``(node, [col])`` pairs and fed to a single
    ``spn.Products`` node with ``num_prods=n_prods``.

    Args:
        inputs: List of groups (lists) of input nodes.
        num_inputs: Unused by this builder (the group length is used instead).
        num_input_cols: Number of columns per input node, one entry per group.
        num_prods: ``num_prods`` value of each group's Products node.
        inf_type: Value inference type forwarded to ``spn.MPEPath``.
        indices: Unused by this builder.
        log: Whether the MPE path is computed in log space.
        output: Unused; kept for a uniform builder signature.

    Returns:
        ``(init_op, path_ops)`` with one counts tensor per input node, in the
        order the nodes appear across all groups.
    """
    prod_nodes = []
    for inps, n_inp_cols, n_prods in zip(inputs, num_input_cols, num_prods):
        col_ids = [int(c) for c in np.arange(n_inp_cols)]
        # One (node, [col]) pair per input for every column combination.
        flat_inputs = [
            (inp, [col])
            for combo in product(col_ids, repeat=len(inps))
            for inp, col in zip(inps, combo)
        ]
        prod_nodes.append(spn.Products(*flat_inputs, num_prods=n_prods))

    root = spn.Sum(*prod_nodes)
    root.generate_weights()

    path_gen = spn.MPEPath(value_inference_type=inf_type, log=bool(log))
    path_gen.get_mpe_path(root)
    path_ops = [path_gen.counts[node] for node in chain.from_iterable(inputs)]
    return spn.initialize_weights(root), path_ops
def perm_products(inputs, num_inputs, num_input_cols, num_prods, inf_type,
                  indices=None, log=False, output=None):
    """Build PermProducts nodes under a root Sum and return MPE-path ops.

    Two input layouts are supported:

    * ``inputs`` is a list of groups of input nodes: one ``spn.PermProducts``
      node is built per group. When ``indices`` is given, it is zipped group
      by group with ``inputs`` and each node is paired with its entry.
    * ``inputs`` is a single node: it is cut into consecutive column blocks of
      width ``num_input_cols[0]``. Group ``k`` receives the next
      ``num_inputs[k]`` blocks. (Presumably all ``num_input_cols`` entries are
      equal here; the total width uses every entry, but the block width uses
      only the first. TODO confirm.)

    Args:
        inputs: List of groups of input nodes, or a single input node.
        num_inputs: Blocks per group (single-node layout only).
        num_input_cols: Columns per block (single-node layout only).
        num_prods: Unused by this builder.
        inf_type: Value inference type forwarded to ``spn.MPEPath``.
        indices: Optional per-group index entries (list layout).
        log: Whether the MPE path is computed in log space.
        output: Unused; kept for a uniform builder signature.

    Returns:
        ``(init_op, path_ops)``: ``path_ops`` is a list of per-input-node
        counts for the list layout, or a single counts tensor for the
        single-node layout.
    """
    if indices is not None:
        # Create inputs list with indices
        inputs_list = [[(inp, ind) for inp, ind in zip(inps, inds)]
                       for inps, inds in zip(inputs, indices)]
    else:
        inputs_list = inputs

    if isinstance(inputs, list):
        # A list of input nodes. Build from ``inputs_list`` (not ``inputs``)
        # so that ``indices`` are actually applied; before, the indexed list
        # was computed and then ignored.
        p = [spn.PermProducts(*inps) for inps in inputs_list]
    else:
        # A single input node: slice it into consecutive column blocks and
        # group the blocks according to ``num_inputs``.
        num_inputs_array = np.array(num_inputs)
        num_input_cols_array = np.array(num_input_cols)
        num_cols = num_input_cols[0]
        num_vars = int(np.sum(num_inputs_array * num_input_cols_array))
        indices_list = [list(range(i, i + num_cols))
                        for i in range(0, num_vars, num_cols)]
        num_inputs_cumsum = np.cumsum(num_inputs_array).tolist()
        num_inputs_cumsum.insert(0, 0)
        inputs_list = [[(inputs, inds) for inds in indices_list[start:stop]]
                       for start, stop in zip(num_inputs_cumsum[:-1],
                                              num_inputs_cumsum[1:])]
        # One PermProducts node per group, all inputs emanating from the
        # common input node.
        p = [spn.PermProducts(*inps) for inps in inputs_list]

    # Connect all PermProducts nodes to a single root Sum node and generate
    # its weights
    root = spn.Sum(*p)
    root.generate_weights()

    mpe_path_gen = spn.MPEPath(value_inference_type=inf_type, log=bool(log))
    mpe_path_gen.get_mpe_path(root)

    if isinstance(inputs, list):
        # Counts are looked up on the raw input nodes, not on the
        # (node, indices) pairs used to build the graph.
        path_ops = [mpe_path_gen.counts[inp]
                    for inp in chain.from_iterable(inputs)]
    else:
        path_ops = mpe_path_gen.counts[inputs]
    return spn.initialize_weights(root), path_ops
def _build_op(self, inputs, placeholders, conf):
    """ Creates the graph using only ParSum nodes

    One ``spn.ParSums`` node (``num_sums=inputs.num_parallel``) is built per
    index list in ``inputs.indices``, all over ``placeholders[0]``. Their
    weights come from consecutive chunks of ``inputs.weights``. A root
    ``spn.Sum`` joins them.

    Returns:
        ``(path_ops, init_op)``: ``path_ops`` are the weight counts of each
        ParSums node (grouped via ``tf.tuple`` with the counts of every
        placeholder); ``init_op`` comes from ``self._initialize_from(root)``.
    """
    # TODO make sure the ivs are correct
    sum_indices, weights, ivs = inputs.indices, inputs.weights, None
    log, inf_type = conf.log, conf.inf_type

    # Node k owns len(sum_indices[k]) * num_parallel weights; split the flat
    # weight array at the cumulative boundaries (last boundary dropped).
    split_points = np.cumsum(
        [len(ind) * inputs.num_parallel for ind in sum_indices])[:-1]
    weights = np.split(weights, split_points)

    parallel_sum_nodes = [
        spn.ParSums((placeholders[0], ind), num_sums=inputs.num_parallel)
        for ind in sum_indices
    ]
    weight_nodes = [
        self._generate_weights(node, w.tolist())
        for node, w in zip(parallel_sum_nodes, weights)
    ]
    # ``ivs`` is currently always None (see TODO), so this branch does not
    # run. It uses a plain loop because it is executed for side effects only.
    if ivs:
        for node, iv in zip(parallel_sum_nodes, ivs):
            node.set_ivs(iv)

    root = spn.Sum(*parallel_sum_nodes)
    self._generate_weights(root)

    mpe_path_gen = spn.MPEPath(value_inference_type=inf_type, log=log)
    mpe_path_gen.get_mpe_path(root)
    path_op = [mpe_path_gen.counts[w] for w in weight_nodes]
    input_counts = [mpe_path_gen.counts[inp] for inp in placeholders]
    # Group weight and input counts together with tf.tuple, then hand back
    # only the weight-count part.
    return tf.tuple(
        path_op + input_counts)[:len(path_op)], self._initialize_from(root)
def dense_block(inputs, num_decomps, num_subsets, num_mixtures,
                num_input_mixtures, balanced, input_dist, inf_type, log=False):
    """Generate a dense SPN made of BLOCK nodes and return MPE-path ops.

    Args:
        inputs: Input node (or list of input nodes) passed to
            ``DenseSPNGenerator.generate``.
        num_decomps, num_subsets, num_mixtures, num_input_mixtures, balanced:
            Forwarded to ``spn.DenseSPNGenerator``.
        input_dist: ``"RAW"`` selects ``InputDist.RAW``; any other value
            selects ``InputDist.MIXTURE``.
        inf_type: Value inference type forwarded to ``spn.MPEPath``.
        log: Whether the MPE path is computed in log space.

    Returns:
        ``(root, init_op, path_ops)`` with one counts tensor per input node.
    """
    # Compare by value. The previous ``input_dist is "RAW"`` tested object
    # identity, which only matched when both strings happened to be the same
    # interned object (and triggers a SyntaxWarning since Python 3.8).
    dist = (spn.DenseSPNGenerator.InputDist.RAW if input_dist == "RAW"
            else spn.DenseSPNGenerator.InputDist.MIXTURE)

    # Dense generator using BLOCK (multi-op) nodes.
    gen = spn.DenseSPNGenerator(
        num_decomps=num_decomps,
        num_subsets=num_subsets,
        num_mixtures=num_mixtures,
        num_input_mixtures=num_input_mixtures,
        balanced=balanced,
        node_type=spn.DenseSPNGenerator.NodeType.BLOCK,
        input_dist=dist)

    # Generate the dense SPN and all its weights, drawn uniformly from [0, 1)
    root = gen.generate(inputs, root_name="root")
    spn.generate_weights(root, tf.initializers.random_uniform(0.0, 1.0))

    mpe_path_gen = spn.MPEPath(value_inference_type=inf_type, log=bool(log))
    mpe_path_gen.get_mpe_path(root)
    input_nodes = inputs if isinstance(inputs, list) else [inputs]
    path_ops = [mpe_path_gen.counts[inp] for inp in input_nodes]
    return root, spn.initialize_weights(root), path_ops
def _build_op(self, inputs, placeholders, conf):
    """Create the graph using a single SumsLayer node and return MPE-path ops.

    Every index list in ``inputs.indices`` is repeated ``inputs.num_parallel``
    times as an input of one ``spn.SumsLayer``. Each repeated sum has size
    ``len(ind)``. A root ``spn.Sum`` sits on top of the layer.

    Side effect: sets the global ``spn.conf.sumslayer_count_sum_strategy`` to
    ``self.sum_count_strategy``; the previous value is not restored.

    Returns:
        ``(path_op, init_op)``: ``path_op`` is a one-element list holding the
        weight counts (grouped via ``tf.tuple`` with the counts of
        ``placeholders[0]``); ``init_op`` comes from
        ``self._initialize_from(root)``.
    """
    # TODO make sure the ivs are correct
    sum_indices, weights, ivs = inputs.indices, inputs.weights, None
    log, inf_type = conf.log, conf.inf_type
    num_parallel = inputs.num_parallel

    repeated_inputs = []
    repeated_sum_sizes = []
    for ind in sum_indices:
        # Each repeated sum covers exactly the columns selected by ``ind``.
        repeated_inputs.extend(
            [(placeholders[0], ind) for _ in range(num_parallel)])
        repeated_sum_sizes.extend([len(ind)] * num_parallel)

    # Globally select the SumsLayer count strategy for this graph
    spn.conf.sumslayer_count_sum_strategy = self.sum_count_strategy
    sums_layer = spn.SumsLayer(*repeated_inputs,
                               num_or_size_sums=repeated_sum_sizes)
    weight_node = self._generate_weights(sums_layer, weights)
    # ``ivs`` is currently always None (see TODO), so this branch does not run
    if ivs:
        sums_layer.set_ivs(*ivs)

    # Connect a single sum to group outcomes
    root = spn.Sum(sums_layer)
    self._generate_weights(root)

    # Then build MPE path Ops
    mpe_path_gen = spn.MPEPath(value_inference_type=inf_type, log=log)
    mpe_path_gen.get_mpe_path(root)
    path_op = [
        tf.tuple([
            mpe_path_gen.counts[weight_node],
            mpe_path_gen.counts[placeholders[0]]
        ])[0]
    ]
    return path_op, self._initialize_from(root)
def test_sumslayer_mpe_path(self, input_sizes, sum_sizes, latent_indicators, log,
                            same_inputs, inf_type, count_strategy, indices,
                            use_unweighted):
    """Compare SumsLayer MPE-path counts against the numpy reference.

    Builds a SumsLayer graph with ``self.build_sumslayer_common``. It then
    computes MPE-path counts for the weight node, the input nodes and, when
    present, the latent indicator nodes, and checks each with
    ``assertAllClose`` against the counts from ``sumslayer_mpe_path_numpy``.

    Side effects: sets the global ``spn.conf.argmax_zero`` and
    ``spn.conf.sumslayer_count_sum_strategy``; neither is restored here.
    """
    spn.conf.argmax_zero = True

    # Set some defaults
    if (1 in sum_sizes or 1 in input_sizes or np.all(np.equal(sum_sizes, sum_sizes[0]))) \
            and use_unweighted:
        # There is not a clean way to solve the issue avoided here. It has to
        # do with floating point errors in numpy vs. tf, leading to
        # unpredictable behavior of argmax. Unweighted values take away any
        # weighting randomness, so the argmax will obtain some values that are
        # very likely to be equal up to these floating point errors. Hence,
        # this unweighted case is skipped entirely (the test returns early)
        # when a sum or input size equals 1 or all sums have the same size
        # (which is typically when the values are 'pseudo'-equal).
        return None

    batch_size = 32
    factor = 10

    # Configure count strategy
    spn.conf.sumslayer_count_sum_strategy = count_strategy
    feed_dict, indices, input_nodes, input_tuples, latent_indicators, values, weights, root_weights = \
        self.sumslayer_prepare_common(
            batch_size, factor, indices, input_sizes, latent_indicators, same_inputs, sum_sizes)
    # The numpy reference uses all-ones root weights only when use_unweighted
    # and log are both set.
    root_weights_np = np.ones_like(
        root_weights) if use_unweighted and log else root_weights
    weight_counts, latent_indicators_counts, value_counts = sumslayer_mpe_path_numpy(
        values, indices, weights, None if not latent_indicators else latent_indicators,
        sum_sizes, inf_type, root_weights_np)
    # Build graph
    init, latent_indicators_nodes, root, weight_node = self.build_sumslayer_common(
        feed_dict, input_tuples, latent_indicators, sum_sizes, weights, root_weights)

    # Then build MPE path Ops
    # NOTE(review): MPEPath is always built with log=True here, while the
    # ``log`` parameter only changes the numpy reference root weights above.
    # Confirm whether ``log=log`` was intended.
    mpe_path_gen = spn.MPEPath(
        value_inference_type=inf_type, log=True, use_unweighted=use_unweighted)
    mpe_path_gen.get_mpe_path(root)
    # Order matters: weight counts first, then input counts, then latent
    # indicator counts last (post-processing below relies on this).
    path_op = [
        mpe_path_gen.counts[node]
        for node in [weight_node] + input_nodes + latent_indicators_nodes
    ]

    # Run graph and do some post-processing
    with self.test_session() as sess:
        sess.run(init)
        out = sess.run(path_op, feed_dict=feed_dict)
        if latent_indicators:
            # Latent indicator counts are the last entry of ``path_op``.
            latent_indicators_counts_out = out[-1]
            # Split along axis 1 into one slice per sum, drop that axis, and
            # keep the first ``size`` columns of each (presumably the counts
            # are padded to the largest sum size; TODO confirm).
            latent_indicators_counts_out = np.split(
                latent_indicators_counts_out,
                indices_or_sections=len(sum_sizes),
                axis=1)
            latent_indicators_counts_out = [
                np.squeeze(iv, axis=1)[:, :size] for iv, size in zip(
                    latent_indicators_counts_out, sum_sizes)
            ]
            out = out[:-1]
        weight_counts_out, *input_counts_out = out
        # Same per-sum split/trim as for the latent indicator counts.
        weight_counts_out = np.split(weight_counts_out,
                                     indices_or_sections=len(sum_sizes),
                                     axis=1)
        weight_counts_out = [
            np.squeeze(w, axis=1)[:, :size]
            for w, size in zip(weight_counts_out, sum_sizes)
        ]

    # With shared inputs, the reference value counts are summed over inputs
    if same_inputs:
        value_counts = [np.sum(value_counts, axis=0)]

    # Test outputs
    [
        self.assertAllClose(inp_count_out, inp_count)
        for inp_count_out, inp_count in zip(input_counts_out, value_counts)
    ]
    [
        self.assertAllClose(w_out, w_out_truth)
        for w_out, w_out_truth in zip(weight_counts_out, weight_counts)
    ]
    if latent_indicators:
        [
            self.assertAllClose(iv_out, iv_true_out)
            for iv_out, iv_true_out in zip(latent_indicators_counts_out,
                                           latent_indicators_counts)
        ]
def products_layer(inputs, num_inputs, num_input_cols, num_prods, inf_type,
                   indices=None, log=False, output=None):
    """Build a single ProductsLayer under a root Sum and return MPE-path ops.

    Two input layouts are supported:

    * ``inputs`` is a list of groups of RawLeaf nodes. For each group, every
      combination that picks one column per node becomes one product of
      ``(node, [col])`` inputs. The group contributes ``n_prods`` product
      sizes, each equal to the group length.
    * ``inputs`` is a single RawLeaf node. Group ``k`` spans
      ``num_inputs[k] * num_input_cols[k]`` consecutive columns (offset by all
      previous groups). Every per-variable column combination becomes one
      product over ``inputs`` with absolute column indices.

    Args:
        inputs: List of groups of input nodes, or a single input node.
        num_inputs: Variables per group (single-node layout only).
        num_input_cols: Columns per variable, one entry per group.
        num_prods: Products per group (list layout only).
        inf_type: Value inference type forwarded to ``spn.MPEPath``.
        indices: Unused by this builder.
        log: Whether the MPE path is computed in log space.
        output: Unused; kept for a uniform builder signature.

    Returns:
        ``(init_op, path_ops)``: ``path_ops`` is a list of per-input-node
        counts for the list layout, or a single counts tensor for the
        single-node layout.
    """
    if isinstance(inputs, list):
        products_inputs = []
        num_or_size_prods = []
        for inps, n_inp_cols, n_prods in zip(inputs, num_input_cols, num_prods):
            col_ids = [int(c) for c in np.arange(n_inp_cols)]
            group_size = len(inps)
            for combo in product(col_ids, repeat=group_size):
                products_inputs.extend(
                    (inp, [col]) for inp, col in zip(inps, combo))
            num_or_size_prods.extend([group_size] * n_prods)
    else:
        flat_inds = []
        base = 0
        for n_inps, n_inp_cols in zip(num_inputs, num_input_cols):
            col_ids = [int(c) for c in np.arange(n_inp_cols)]
            # Absolute start column of each variable in this group.
            offsets = list(range(base, base + n_inps * n_inp_cols, n_inp_cols))
            base += n_inps * n_inp_cols
            for combo in product(col_ids, repeat=n_inps):
                # Indices must be plain ints for the input parser.
                flat_inds.append(
                    [int(col + off) for col, off in zip(combo, offsets)])
        products_inputs = [(inputs, inds) for inds in flat_inds]
        num_or_size_prods = [len(inds) for inds in flat_inds]

    layer = spn.ProductsLayer(*products_inputs,
                              num_or_size_prods=num_or_size_prods)
    root = spn.Sum(layer)
    root.generate_weights()

    path_gen = spn.MPEPath(value_inference_type=inf_type, log=bool(log))
    path_gen.get_mpe_path(root)
    if isinstance(inputs, list):
        path_ops = [path_gen.counts[node] for node in chain.from_iterable(inputs)]
    else:
        path_ops = path_gen.counts[inputs]
    return spn.initialize_weights(root), path_ops
def test_generate_spn(self, num_decomps, num_subsets, num_mixtures,
                      num_input_mixtures, input_dims, input_dist, balanced,
                      node_type, log_weights):
    """A generic test for DenseSPNGenerator.

    Generates three sub-SPNs over the same IVs inputs and builds an upper SPN
    over their root-level nodes. For both the first sub-SPN and the full SPN
    it checks that the SPN is valid, that values summed over every possible
    input assignment equal 1.0 (linear and log value ops), and that the MPE
    path counts per variable value are balanced (see the final assertions).
    """
    if input_dist == spn.DenseSPNGenerator.InputDist.RAW \
            and num_input_mixtures != 1:
        # Redundant test case, so just return
        return

    # Input parameters
    num_inputs = input_dims[0]
    num_vars = input_dims[1]
    num_vals = 2

    printc("\n- num_inputs: %s" % num_inputs)
    printc("- num_vars: %s" % num_vars)
    printc("- num_vals: %s" % num_vals)
    printc("- num_decomps: %s" % num_decomps)
    printc("- num_subsets: %s" % num_subsets)
    printc("- num_mixtures: %s" % num_mixtures)
    printc("- input_dist: %s" % ("MIXTURE" if input_dist ==
                                 spn.DenseSPNGenerator.InputDist.MIXTURE else "RAW"))
    printc("- balanced: %s" % balanced)
    printc("- num_input_mixtures: %s" % num_input_mixtures)
    printc("- node_type: %s" % ("SINGLE" if node_type ==
                                spn.DenseSPNGenerator.NodeType.SINGLE else
                                "BLOCK" if node_type ==
                                spn.DenseSPNGenerator.NodeType.BLOCK else "LAYER"))
    printc("- log_weights: %s" % log_weights)

    # Inputs
    inputs = [
        spn.IVs(num_vars=num_vars, num_vals=num_vals, name=("IVs_%d" % (i + 1)))
        for i in range(num_inputs)
    ]

    gen = spn.DenseSPNGenerator(num_decomps=num_decomps,
                                num_subsets=num_subsets,
                                num_mixtures=num_mixtures,
                                input_dist=input_dist,
                                balanced=balanced,
                                num_input_mixtures=num_input_mixtures,
                                node_type=node_type)

    # Generate Sub-SPNs (three sub-SPNs over the same inputs)
    sub_spns = [
        gen.generate(*inputs, root_name=("sub_root_%d" % (i + 1)))
        for i in range(3)
    ]

    # Generate random weights for the first sub-SPN
    with tf.name_scope("Weights"):
        spn.generate_weights(sub_spns[0],
                             tf.initializers.random_uniform(0.0, 1.0),
                             log=log_weights)

    # Initialize weights of the first sub-SPN
    sub_spn_init = spn.initialize_weights(sub_spns[0])

    # Testing validity of the first sub-SPN
    self.assertTrue(sub_spns[0].is_valid())

    # Generate value ops of the first sub-SPN
    sub_spn_v = sub_spns[0].get_value()
    sub_spn_v_log = sub_spns[0].get_log_value()

    # Generate path ops of the first sub-SPN
    sub_spn_mpe_path_gen = spn.MPEPath(log=False)
    sub_spn_mpe_path_gen_log = spn.MPEPath(log=True)
    sub_spn_mpe_path_gen.get_mpe_path(sub_spns[0])
    sub_spn_mpe_path_gen_log.get_mpe_path(sub_spns[0])
    sub_spn_path = [sub_spn_mpe_path_gen.counts[inp] for inp in inputs]
    sub_spn_path_log = [
        sub_spn_mpe_path_gen_log.counts[inp] for inp in inputs
    ]

    # Collect all weight nodes of the first sub-SPN
    # (the collected list is not used further in this test)
    sub_spn_weight_nodes = []

    def fun(node):
        if node.is_param:
            sub_spn_weight_nodes.append(node)

    spn.traverse_graph(sub_spns[0], fun=fun)

    # Generate an upper-SPN over sub-SPNs: take the nodes feeding each
    # sub-SPN root, then add num_top_mixtures[k] sums over those of sub-SPN k
    products_lower = []
    for sub_spn in sub_spns:
        products_lower.append([v.node for v in sub_spn.values])

    num_top_mixtures = [2, 1, 3]
    sums_lower = []
    for prods, num_top_mix in zip(products_lower, num_top_mixtures):
        if node_type == spn.DenseSPNGenerator.NodeType.SINGLE:
            sums_lower.append(
                [spn.Sum(*prods) for _ in range(num_top_mix)])
        elif node_type == spn.DenseSPNGenerator.NodeType.BLOCK:
            sums_lower.append([spn.ParSums(*prods, num_sums=num_top_mix)])
        else:
            # SumsLayer: the product list is repeated once per sum
            sums_lower.append([
                spn.SumsLayer(*prods * num_top_mix,
                              num_or_size_sums=num_top_mix)
            ])

    # Generate upper-SPN
    root = gen.generate(*list(itertools.chain(*sums_lower)),
                        root_name="root")

    # Generate random weights for the SPN
    with tf.name_scope("Weights"):
        spn.generate_weights(root, tf.initializers.random_uniform(0.0, 1.0),
                             log=log_weights)

    # Initialize weights of the SPN
    spn_init = spn.initialize_weights(root)

    # Testing validity of the SPN
    self.assertTrue(root.is_valid())

    # Generate value ops of the SPN
    spn_v = root.get_value()
    spn_v_log = root.get_log_value()

    # Generate path ops of the SPN
    spn_mpe_path_gen = spn.MPEPath(log=False)
    spn_mpe_path_gen_log = spn.MPEPath(log=True)
    spn_mpe_path_gen.get_mpe_path(root)
    spn_mpe_path_gen_log.get_mpe_path(root)
    spn_path = [spn_mpe_path_gen.counts[inp] for inp in inputs]
    spn_path_log = [spn_mpe_path_gen_log.counts[inp] for inp in inputs]

    # Collect all weight nodes in the SPN
    # (redefines ``fun``; the collected list is not used further in this test)
    spn_weight_nodes = []

    def fun(node):
        if node.is_param:
            spn_weight_nodes.append(node)

    spn.traverse_graph(root, fun=fun)

    # Create a session
    with self.test_session() as sess:
        # Initializing weights
        sess.run(sub_spn_init)
        sess.run(spn_init)

        # Generate input feed: every possible assignment of num_vals values
        # to all num_inputs * num_vars variables, one row per assignment
        feed = np.array(
            list(
                itertools.product(range(num_vals),
                                  repeat=(num_inputs * num_vars))))
        batch_size = feed.shape[0]
        feed_dict = {}
        # Column blocks of the feed go to the corresponding IVs input
        for inp, f in zip(inputs, np.split(feed, num_inputs, axis=1)):
            feed_dict[inp] = f

        # Compute all values and paths of sub-SPN
        sub_spn_out = sess.run(sub_spn_v, feed_dict=feed_dict)
        sub_spn_out_log = sess.run(tf.exp(sub_spn_v_log),
                                   feed_dict=feed_dict)
        sub_spn_out_path = sess.run(sub_spn_path, feed_dict=feed_dict)
        sub_spn_out_path_log = sess.run(sub_spn_path_log,
                                        feed_dict=feed_dict)

        # Compute all values and paths of the complete SPN
        spn_out = sess.run(spn_v, feed_dict=feed_dict)
        spn_out_log = sess.run(tf.exp(spn_v_log), feed_dict=feed_dict)
        spn_out_path = sess.run(spn_path, feed_dict=feed_dict)
        spn_out_path_log = sess.run(spn_path_log, feed_dict=feed_dict)

        # Test if partition function of the sub-SPN and of the
        # complete SPN is 1.0 (the feed enumerates every assignment)
        self.assertAlmostEqual(sub_spn_out.sum(), 1.0, places=6)
        self.assertAlmostEqual(sub_spn_out_log.sum(), 1.0, places=6)
        self.assertAlmostEqual(spn_out.sum(), 1.0, places=6)
        self.assertAlmostEqual(spn_out_log.sum(), 1.0, places=6)

        # Test if the sum of counts for each value of each variable
        # (num_inputs * num_vars variables, num_vals values each) equals
        # batch_size / num_vals: in the exhaustive feed each variable takes
        # each value in exactly that many rows
        self.assertEqual(
            np.sum(np.hstack(sub_spn_out_path), axis=0).tolist(),
            [batch_size // num_vals] * num_inputs * num_vars * num_vals)
        self.assertEqual(
            np.sum(np.hstack(sub_spn_out_path_log), axis=0).tolist(),
            [batch_size // num_vals] * num_inputs * num_vars * num_vals)
        self.assertEqual(
            np.sum(np.hstack(spn_out_path), axis=0).tolist(),
            [batch_size // num_vals] * num_inputs * num_vars * num_vals)
        self.assertEqual(
            np.sum(np.hstack(spn_out_path_log), axis=0).tolist(),
            [batch_size // num_vals] * num_inputs * num_vars * num_vals)