def test_small_spn(self):
    num_vars = 13
    indicator_leaf = spn.IndicatorLeaf(num_vals=2, num_vars=num_vars)
    randomize = BlockRandomDecompositions(indicator_leaf, num_decomps=2)
    p0 = BlockPermuteProduct(randomize, num_factors=4)
    s0 = BlockSum(p0, num_sums_per_block=3)
    p1 = BlockPermuteProduct(s0, num_factors=2)
    s1 = BlockSum(p1, num_sums_per_block=3)
    p2 = BlockReduceProduct(s1, num_factors=2)
    m = BlockMergeDecompositions(p2, num_decomps=1)
    root = BlockRootSum(m)
    latent = root.generate_latent_indicators(name="Latent")
    spn.generate_weights(root, initializer=tf.initializers.random_uniform())

    valgen = spn.LogValue()
    val = valgen.get_value(root)
    logsum = tf.reduce_logsumexp(val)

    # Enumerate all 2**num_vars assignments of the binary variables
    num_possibilities = 2 ** num_vars
    nums = np.arange(num_possibilities).reshape((num_possibilities, 1))
    powers = 2 ** np.arange(num_vars).reshape((1, num_vars))
    leaf_feed = np.bitwise_and(nums, powers) // powers

    with self.test_session() as sess:
        sess.run(spn.initialize_weights(root))
        out = sess.run(
            logsum, {
                indicator_leaf: leaf_feed,
                latent: -np.ones((leaf_feed.shape[0], 1), dtype=np.int32)
            })
        # Log-sum over all assignments should be 0, i.e. the SPN is normalized
        self.assertAllClose(out, 0.0)
def test_compute_value_sum(self, grid_size):
    indicator_leaf = spn.IndicatorLeaf(num_vals=2, num_vars=grid_size ** 2)
    convsum = ConvSums(indicator_leaf,
                       spatial_dim_sizes=[grid_size, grid_size],
                       num_channels=4)
    convsum2 = ConvSums(indicator_leaf,
                        spatial_dim_sizes=[grid_size, grid_size],
                        num_channels=4)
    dense_generator = spn.DenseSPNGenerator(
        num_mixtures=4, num_subsets=4, num_decomps=1,
        input_dist=spn.DenseSPNGenerator.InputDist.MIXTURE)
    root = dense_generator.generate(convsum, convsum2)
    spn.generate_weights(root, initializer=tf.initializers.random_uniform())
    init = spn.initialize_weights(root)

    num_possibilities = 2 ** (grid_size ** 2)
    nums = np.arange(num_possibilities).reshape((num_possibilities, 1))
    powers = 2 ** np.arange(grid_size ** 2).reshape((1, grid_size ** 2))
    indicator_feed = np.bitwise_and(nums, powers) // powers

    value_op = spn.LogValue(spn.InferenceType.MARGINAL).get_value(root)
    value_op_sum = tf.reduce_logsumexp(value_op)

    with self.test_session() as sess:
        sess.run(init)
        root_sum = sess.run(value_op_sum,
                            feed_dict={indicator_leaf: indicator_feed})
        print(indicator_feed[:10])
        self.assertAllClose(root_sum, 0.0)
def generate_random_weights(self, trainable=True):
    """Generates random weights for this SPN."""
    weight_init_value = spn.ValueType.RANDOM_UNIFORM(
        self._weight_init_min, self._weight_init_max)
    spn.generate_weights(self._root,
                         init_value=weight_init_value,
                         trainable=trainable)
def test_dropconnect(self):
    ivs = spn.IVs(num_vals=2, num_vars=4)
    s = spn.Sum(ivs, dropconnect_keep_prob=0.5)
    spn.generate_weights(s)
    init = spn.initialize_weights(s)

    mask = [[0., 1., 0., 1., 1., 1., 0., 1.],
            [1., 0., 0., 0., 0., 0., 1., 0.]]
    # Fix the dropconnect mask so the expected value can be computed analytically
    s._create_dropout_mask = MagicMock(
        return_value=tf.expand_dims(tf.log(mask), 1))

    val_op = tf.exp(s.get_log_value())

    mask = tf.constant(mask, dtype=tf.float32)
    truth = tf.reduce_mean(mask, axis=-1, keepdims=True)

    with self.test_session() as sess:
        sess.run(init)
        dropconnect_out, truth_out = sess.run(
            [val_op, truth],
            feed_dict={ivs: -np.ones((2, 4), dtype=np.int32)})
        self.assertAllClose(dropconnect_out, truth_out)
sum_12 = spn.Sum((indicator_leaves, [0, 1]), name="sum_12")

# Connect another two sums to indicators of the second variable
sum_21 = spn.Sum((indicator_leaves, [2, 3]), name="sum_21")
sum_22 = spn.Sum((indicator_leaves, [2, 3]), name="sum_22")

# Connect three product nodes
prod_1 = spn.Product(sum_11, sum_21, name="prod_1")
prod_2 = spn.Product(sum_11, sum_22, name="prod_2")
prod_3 = spn.Product(sum_12, sum_22, name="prod_3")

# Connect a root sum
root = spn.Sum(prod_1, prod_2, prod_3, name="root")

# Connect a latent indicator
indicator_y = root.generate_latent_indicators(name="indicator_y")  # Can be added manually

# Generate weights
spn.generate_weights(root, initializer=tf.initializers.random_uniform())  # Can be added manually

# Inspect
print(root.get_num_nodes())
print(root.get_scope())
print(root.is_valid())

# Visualize SPN graph
spn.display_spn_graph(root)
x = spn.LocalSums(x, num_channels=64)

# Create final layer of products
full_scope_prod = spn.ConvProductsDepthwise(x,
                                            padding='wicker_top',
                                            kernel_size=2,
                                            strides=1,
                                            dilation_rate=2 ** stack_size)
class_roots = spn.ParallelSums(full_scope_prod, num_sums=num_classes)
root = spn.Sum(class_roots)

# Add an IndicatorLeaf node to the root as a latent class variable
class_indicators = root.generate_latent_indicators()

# Generate the weights for the SPN rooted at `root`
spn.generate_weights(root, log=True, initializer=tf.initializers.random_uniform())

print("SPN depth: {}".format(root.get_depth()))
print("Number of products layers: {}".format(
    root.get_num_nodes(node_type=spn.ConvProducts)))
print("Number of sums layers: {}".format(
    root.get_num_nodes(node_type=spn.LocalSums)))


# In[4]:


spn.display_tf_graph()


# ## Defining the TensorFlow graph
#
# Now that we have defined the SPN graph, we can declare the TensorFlow
# operations needed for training and evaluation. The `MPEState` class can be
# used to find the MPE state of any node in the graph. In
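# A minimal sketch, assuming spn.MPEState exposes a get_state(root, *nodes)
# method as in the libspn tutorials, of how the MPE state of the latent class
# indicator defined above could be queried (not taken from the original script):
mpe_state_gen = spn.MPEState()
# Op for the MPE assignment of the latent class variable; it would be evaluated
# with sess.run together with a feed for the leaf indicators of the SPN.
class_indicators_mpe, = mpe_state_gen.get_state(root, class_indicators)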
def test_generate_spn(self, num_decomps, num_subsets, num_mixtures,
                      num_input_mixtures, input_dims, input_dist, balanced,
                      node_type, log_weights):
    """A generic test for DenseSPNGenerator."""
    if input_dist == spn.DenseSPNGenerator.InputDist.RAW \
            and num_input_mixtures != 1:
        # Redundant test case, so just return
        return

    # Input parameters
    num_inputs = input_dims[0]
    num_vars = input_dims[1]
    num_vals = 2

    printc("\n- num_inputs: %s" % num_inputs)
    printc("- num_vars: %s" % num_vars)
    printc("- num_vals: %s" % num_vals)
    printc("- num_decomps: %s" % num_decomps)
    printc("- num_subsets: %s" % num_subsets)
    printc("- num_mixtures: %s" % num_mixtures)
    printc("- input_dist: %s" % ("MIXTURE" if input_dist ==
           spn.DenseSPNGenerator.InputDist.MIXTURE else "RAW"))
    printc("- balanced: %s" % balanced)
    printc("- num_input_mixtures: %s" % num_input_mixtures)
    printc("- node_type: %s" % ("SINGLE" if node_type ==
           spn.DenseSPNGenerator.NodeType.SINGLE else "BLOCK" if node_type ==
           spn.DenseSPNGenerator.NodeType.BLOCK else "LAYER"))
    printc("- log_weights: %s" % log_weights)

    # Inputs
    inputs = [
        spn.IndicatorLeaf(num_vars=num_vars,
                          num_vals=num_vals,
                          name=("IndicatorLeaf_%d" % (i + 1)))
        for i in range(num_inputs)
    ]

    gen = spn.DenseSPNGenerator(num_decomps=num_decomps,
                                num_subsets=num_subsets,
                                num_mixtures=num_mixtures,
                                input_dist=input_dist,
                                balanced=balanced,
                                num_input_mixtures=num_input_mixtures,
                                node_type=node_type)

    # Generate sub-SPNs
    sub_spns = [
        gen.generate(*inputs, root_name=("sub_root_%d" % (i + 1)))
        for i in range(3)
    ]

    # Generate random weights for the first sub-SPN
    with tf.name_scope("Weights"):
        spn.generate_weights(sub_spns[0],
                             tf.initializers.random_uniform(0.0, 1.0),
                             log=log_weights)

    # Initialize weights of the first sub-SPN
    sub_spn_init = spn.initialize_weights(sub_spns[0])

    # Testing validity of the first sub-SPN
    self.assertTrue(sub_spns[0].is_valid())

    # Generate value ops of the first sub-SPN
    sub_spn_v = sub_spns[0].get_value()
    sub_spn_v_log = sub_spns[0].get_log_value()

    # Generate path ops of the first sub-SPN
    sub_spn_mpe_path_gen = spn.MPEPath(log=False)
    sub_spn_mpe_path_gen_log = spn.MPEPath(log=True)
    sub_spn_mpe_path_gen.get_mpe_path(sub_spns[0])
    sub_spn_mpe_path_gen_log.get_mpe_path(sub_spns[0])
    sub_spn_path = [sub_spn_mpe_path_gen.counts[inp] for inp in inputs]
    sub_spn_path_log = [
        sub_spn_mpe_path_gen_log.counts[inp] for inp in inputs
    ]

    # Collect all weight nodes of the first sub-SPN
    sub_spn_weight_nodes = []

    def fun(node):
        if node.is_param:
            sub_spn_weight_nodes.append(node)

    spn.traverse_graph(sub_spns[0], fun=fun)

    # Generate an upper SPN over the sub-SPNs
    products_lower = []
    for sub_spn in sub_spns:
        products_lower.append([v.node for v in sub_spn.values])

    num_top_mixtures = [2, 1, 3]
    sums_lower = []
    for prods, num_top_mix in zip(products_lower, num_top_mixtures):
        if node_type == spn.DenseSPNGenerator.NodeType.SINGLE:
            sums_lower.append(
                [spn.Sum(*prods) for _ in range(num_top_mix)])
        elif node_type == spn.DenseSPNGenerator.NodeType.BLOCK:
            sums_lower.append(
                [spn.ParallelSums(*prods, num_sums=num_top_mix)])
        else:
            sums_lower.append([
                spn.SumsLayer(*prods * num_top_mix,
                              num_or_size_sums=num_top_mix)
            ])

    # Generate the upper SPN
    root = gen.generate(*list(itertools.chain(*sums_lower)), root_name="root")

    # Generate random weights for the SPN
    with tf.name_scope("Weights"):
        spn.generate_weights(root,
                             tf.initializers.random_uniform(0.0, 1.0),
                             log=log_weights)

    # Initialize weights of the SPN
    spn_init = spn.initialize_weights(root)

    # Testing validity of the SPN
    self.assertTrue(root.is_valid())

    # Generate value ops of the SPN
    spn_v = root.get_value()
    spn_v_log = root.get_log_value()

    # Generate path ops of the SPN
    spn_mpe_path_gen = spn.MPEPath(log=False)
    spn_mpe_path_gen_log = spn.MPEPath(log=True)
    spn_mpe_path_gen.get_mpe_path(root)
    spn_mpe_path_gen_log.get_mpe_path(root)
    spn_path = [spn_mpe_path_gen.counts[inp] for inp in inputs]
    spn_path_log = [spn_mpe_path_gen_log.counts[inp] for inp in inputs]

    # Collect all weight nodes in the SPN
    spn_weight_nodes = []

    def fun(node):
        if node.is_param:
            spn_weight_nodes.append(node)

    spn.traverse_graph(root, fun=fun)

    # Create a session
    with self.test_session() as sess:
        # Initializing weights
        sess.run(sub_spn_init)
        sess.run(spn_init)

        # Generate input feed
        feed = np.array(
            list(itertools.product(range(num_vals),
                                   repeat=(num_inputs * num_vars))))
        batch_size = feed.shape[0]
        feed_dict = {}
        for inp, f in zip(inputs, np.split(feed, num_inputs, axis=1)):
            feed_dict[inp] = f

        # Compute all values and paths of sub-SPN
        sub_spn_out = sess.run(sub_spn_v, feed_dict=feed_dict)
        sub_spn_out_log = sess.run(tf.exp(sub_spn_v_log), feed_dict=feed_dict)
        sub_spn_out_path = sess.run(sub_spn_path, feed_dict=feed_dict)
        sub_spn_out_path_log = sess.run(sub_spn_path_log, feed_dict=feed_dict)

        # Compute all values and paths of the complete SPN
        spn_out = sess.run(spn_v, feed_dict=feed_dict)
        spn_out_log = sess.run(tf.exp(spn_v_log), feed_dict=feed_dict)
        spn_out_path = sess.run(spn_path, feed_dict=feed_dict)
        spn_out_path_log = sess.run(spn_path_log, feed_dict=feed_dict)

        # Test if the partition function of the sub-SPN and of the
        # complete SPN is 1.0
        self.assertAlmostEqual(sub_spn_out.sum(), 1.0, places=6)
        self.assertAlmostEqual(sub_spn_out_log.sum(), 1.0, places=6)
        self.assertAlmostEqual(spn_out.sum(), 1.0, places=6)
        self.assertAlmostEqual(spn_out_log.sum(), 1.0, places=6)

        # Test if the sum of counts for each value of each variable
        # (6 variables, with 2 values each) = batch-size / num-vals
        self.assertEqual(
            np.sum(np.hstack(sub_spn_out_path), axis=0).tolist(),
            [batch_size // num_vals] * num_inputs * num_vars * num_vals)
        self.assertEqual(
            np.sum(np.hstack(sub_spn_out_path_log), axis=0).tolist(),
            [batch_size // num_vals] * num_inputs * num_vars * num_vals)
        self.assertEqual(
            np.sum(np.hstack(spn_out_path), axis=0).tolist(),
            [batch_size // num_vals] * num_inputs * num_vars * num_vals)
        self.assertEqual(
            np.sum(np.hstack(spn_out_path_log), axis=0).tolist(),
            [batch_size // num_vals] * num_inputs * num_vars * num_vals)
    input_dist=input_dist)

# Generate a dense SPN for each class
class_roots = [
    dense_generator.generate(leaf_indicators) for _ in range(num_classes)
]

# Connect sub-SPNs to a root
root = spn.Sum(*class_roots, name="RootSum")
root = spn.convert_to_layer_nodes(root)

# Add an IVs node to the root as a latent class variable
class_indicators = root.generate_latent_indicators()

# Generate the weights for the SPN rooted at `root`
spn.generate_weights(root)

print("SPN depth: {}".format(root.get_depth()))
print("Number of products layers: {}".format(
    root.get_num_nodes(node_type=spn.ProductsLayer)))
print("Number of sums layers: {}".format(
    root.get_num_nodes(node_type=spn.SumsLayer)))

# Op for initializing all weights
weight_init_op = spn.initialize_weights(root)
# Op for getting the log probability of the root
root_log_prob = root.get_log_value(inference_type=inference_type)

# Helper for constructing learning ops
em_learning = spn.GDLearning(initial_accum_value=initial_accum_value,
                             root=root,
def train(args):
    reporter = ExperimentLogger(args.name, args.log_base_path)
    reporter.write_hyperparameter_dict(vars(args))
    test_x, test_y, train_x, train_y, num_classes = load_data(args)

    num_rows, num_cols = train_x.shape[1:3]
    num_vars = train_x.shape[1] * train_x.shape[2]
    num_dims = train_x.shape[-1]

    train_x = np.squeeze(train_x.reshape(-1, num_vars, num_dims))
    test_x = np.squeeze(test_x.reshape(-1, num_vars, num_dims))

    train_augmented_iterator = ImageIterator(
        [train_x, train_y],
        batch_size=args.batch_size,
        shuffle=True,
        width_shift_range=args.width_shift_range,
        height_shift_range=args.height_shift_range,
        shear_range=args.shear_range,
        rotation_range=args.rotation_range,
        zoom_range=args.zoom_range,
        horizontal_flip=args.horizontal_flip,
        image_dims=(num_rows, num_cols))

    train_iterator = DataIterator([train_x, train_y],
                                  batch_size=args.batch_size,
                                  shuffle=True)
    test_iterator = DataIterator([test_x, test_y],
                                 batch_size=args.eval_batch_size,
                                 shuffle=False)

    in_var, root = build_spn(args, num_dims, num_vars, train_x, train_y)
    spn.generate_weights(root,
                         log=args.log_weights,
                         initializer=tf.initializers.random_uniform(
                             args.weight_init_min, args.weight_init_max))
    init_weights = spn.initialize_weights(root)

    correct, labels_node, loss, likelihood, update_op, pred_op, reg_loss, loss_per_sample, \
        mpe_in_var = setup_learning(args, in_var, root)

    # Set up the evaluation tasks
    def evaluate_classification(image_batch, labels_batch, epoch, step):
        feed_dict = {in_var: image_batch}
        if args.supervised:
            feed_dict[labels_node] = labels_batch
        loss_out, correct_out, likelihood_out, reg_loss_out, loss_per_sample_out = sess.run(
            [loss, correct, likelihood, reg_loss, loss_per_sample],
            feed_dict=feed_dict)
        return [loss_out, reg_loss_out, correct_out * 100, likelihood_out]

    def evaluate_likelihood(image_batch, labels_batch, epoch, step):
        feed_dict = {in_var: image_batch}
        if args.supervised:
            feed_dict[labels_node] = labels_batch
        return sess.run(likelihood, feed_dict=feed_dict)

    # Default evaluation metrics to be measured at the end of each epoch
    metrics = ["loss", "reg_loss", "accuracy", 'likelihood']
    gm_default = GroupedMetrics(
        reporter=reporter,
        reduce_fun=np.mean if not args.novelty_detection else 'roc')

    if args.supervised:
        gm_default.add_task('test_epoch.csv',
                            fun=evaluate_classification,
                            iterator=test_iterator,
                            metric_names=metrics,
                            desc="Evaluate test ",
                            batch_size=args.eval_batch_size)
        gm_default.add_task('train_epoch.csv',
                            fun=evaluate_classification,
                            iterator=train_iterator,
                            metric_names=metrics,
                            desc="Evaluate train",
                            batch_size=args.eval_batch_size,
                            return_val=True)
    else:
        gm_default.add_task('test_epoch.csv',
                            fun=evaluate_likelihood,
                            iterator=test_iterator,
                            metric_names=['likelihood'],
                            desc="Likelihood test ",
                            batch_size=args.eval_batch_size)
        gm_default.add_task('train_epoch.csv',
                            fun=evaluate_likelihood,
                            iterator=train_iterator,
                            metric_names=['likelihood'],
                            desc="Likelihood train",
                            batch_size=args.eval_batch_size,
                            return_val=True)

    if args.completion:
        with tf.name_scope("CompletionSummary"):
            truth = in_var.feed if not args.discrete else tf.placeholder(
                tf.float32, [None, num_vars])
            completion_indices = tf.equal(in_var.feed, -1) if args.discrete \
                else tf.logical_not(in_var.evidence)
            shape = (-1, num_rows, num_cols, num_dims)
            mosaic = impainting_mosaic(
                reconstruction=tf.reshape(mpe_in_var, shape),
                truth=tf.reshape(truth, shape),
                completion_indices=tf.reshape(completion_indices, shape),
                num_rows=4,
                batch_size=args.completion_batch_size,
                invert=args.dataset == "mnist")
            mosaic_summary = tf.summary.image("Completion", mosaic)

        def completion_left(image_batch, labels_batch, epoch, step):
            shape = [len(image_batch), num_rows, num_cols // 2]
            if np.prod(shape) > image_batch.size:
                shape.append(3)
            completion_ind = np.concatenate([np.ones(shape), np.zeros(shape)], axis=2) \
                .astype(np.bool)
            evidence_ind = np.logical_not(completion_ind)
            evidence_ind = np.reshape(evidence_ind, image_batch.shape[:2])
            completion_ind = np.reshape(completion_ind, image_batch.shape[:2])
            return _measure_completion(completion_ind, epoch, evidence_ind,
                                       image_batch, labels_batch, step,
                                       tag='left', writer=test_writer_left)

        def completion_bottom(image_batch, labels_batch, epoch, step):
            shape = [len(image_batch), num_rows // 2, num_cols]
            if np.prod(shape) > image_batch.size:
                shape.append(3)
            completion_ind = np.concatenate([np.zeros(shape), np.ones(shape)], axis=1) \
                .astype(np.bool)
            evidence_ind = np.logical_not(completion_ind)
            evidence_ind = np.reshape(evidence_ind, image_batch.shape[:2])
            completion_ind = np.reshape(completion_ind, image_batch.shape[:2])
            return _measure_completion(completion_ind, epoch, evidence_ind,
                                       image_batch, labels_batch, step,
                                       tag='bottom', writer=test_writer_bottom)

        def _measure_completion(completion_ind, epoch, evidence_ind,
                                image_batch, labels_batch, step, writer,
                                tag="comp"):
            if args.discrete:
                im = image_batch.copy()
                im[completion_ind] = -1
                feed_dict = {in_var: im}
            else:
                feed_dict = {
                    in_var: image_batch,
                    in_var.evidence: evidence_ind
                }
            if args.supervised:
                feed_dict[labels_node] = labels_batch
            if step == 0:
                if args.discrete:
                    feed_dict[truth] = image_batch
                mpe_in_var_out, mosaic_summary_out, mosaic_out = sess.run(
                    [mpe_in_var, mosaic_summary, mosaic], feed_dict=feed_dict)
                writer.add_summary(mosaic_summary_out, epoch)
                reporter.write_image(
                    np.squeeze(mosaic_out, axis=0),
                    'completion/epoch_{}_{}.png'.format(epoch, tag))
            else:
                mpe_in_var_out = sess.run(mpe_in_var, feed_dict=feed_dict)

            if not args.normalize_data:
                mpe_in_var_out *= 255
                orig = image_batch.copy() * 255
            else:
                orig = image_batch

            hamming = np.equal(orig, mpe_in_var_out)[completion_ind]
            max_fluctuation = args.num_vals ** 2 if args.discrete else 1.0
            l2 = np.square(orig - mpe_in_var_out)[completion_ind]
            l1 = np.abs(orig - mpe_in_var_out)[completion_ind]

            hamming = np.mean(hamming.reshape(len(orig), -1), axis=-1)
            l2 = np.mean(l2.reshape(len(orig), -1), axis=-1)
            l1 = np.mean(l1.reshape(len(orig), -1), axis=-1)
            psnr = 10 * np.log10(max_fluctuation / l2)
            tv = tv_norm(
                mpe_in_var_out.reshape(-1, num_rows, num_cols, num_dims),
                completion_ind.reshape(-1, num_rows, num_cols, num_dims))
            return l1, l2, hamming, psnr, tv

        gm_default.add_task(
            "test_epoch.csv",
            fun=completion_bottom,
            iterator=test_iterator,
            metric_names=['l1_b', 'l2_b', 'hamming_b', 'psnr_b', 'tv_b'],
            desc='Completion bottom',
            batch_size=args.completion_batch_size)
        gm_default.add_task(
            "test_epoch.csv",
            fun=completion_left,
            iterator=test_iterator,
            metric_names=['l1_l', 'l2_l', 'hamming_l', 'psnr_l', 'tv_l'],
            desc='Completion left ',
            batch_size=args.completion_batch_size)

    # Report the total number of trainable variables
    trainable_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
    total_trainable_parameters = sum(
        [np.prod(v.shape.as_list()) for v in trainable_vars])
    reporter.write_line('num_trainable_parameters.csv',
                        total_trainable_parameters=total_trainable_parameters)
    logger.info(
        "Num trainable parameters = {}".format(total_trainable_parameters))

    # Remember the five last metrics for determining the stop criterion
    progress_history = deque(maxlen=5)
    progress_metric = 'loss' if args.supervised else 'likelihood'

    with tf.Session() as sess:
        train_writer, test_writer = initialize_graph(init_weights, reporter, sess)
        test_writer_left = reporter.tfwriter('test', 'completion', 'left',
                                             exist_ok=True)
        test_writer_bottom = reporter.tfwriter('test', 'completion', 'bottom',
                                               exist_ok=True)

        progress_prev = gm_default.evaluate_one_epoch(0)[progress_metric]
        progress_history.append(progress_prev)

        for epoch in range(args.num_epochs):
            # Train, nothing more nothing less
            for image_batch, labels_batch in train_augmented_iterator.iter_epoch("Train"):
                if args.input_dropout:
                    dropout_mask = np.less(
                        np.random.rand(*image_batch.shape[:2]),
                        args.input_dropout)
                    if args.discrete:
                        image_batch_copy = image_batch.copy()
                        image_batch_copy[dropout_mask] = -1
                        feed_dict = {in_var: image_batch_copy}
                    else:
                        feed_dict = {
                            in_var: image_batch,
                            in_var.evidence: np.logical_not(dropout_mask)
                        }
                else:
                    feed_dict = {in_var: image_batch}
                if args.supervised:
                    feed_dict[labels_node] = labels_batch
                sess.run(update_op, feed_dict=feed_dict)

            # Check stopping criterion
            progress_epoch = gm_default.evaluate_one_epoch(epoch + 1)[progress_metric]
            progress_history.append(progress_epoch)
            if len(progress_history) == 5 and np.std(progress_history) < args.stop_epsilon or \
                    np.isnan(progress_epoch) or progress_epoch == 0.0:
                print("Stopping criterion reached!")
                break

        # Store locations and scales
        if not args.discrete:
            loc, scale = sess.run([in_var.loc_variable, in_var.scale_variable])
            reporter.write_numpy(loc, "dist/loc")
            reporter.write_numpy(scale, "dist/scale")
            print("Locations\n", np.unique(loc))
            print("\nScales:\n", np.unique(scale))
def test_compare_manual_conv(self, log_weights, inference_type):
    spn.conf.argmax_zero = True
    spatial_dims = [4, 4]
    nrows, ncols = spatial_dims
    num_vals = 4
    batch_size = 128
    num_vars = spatial_dims[0] * spatial_dims[1]
    indicator_leaf = spn.IndicatorLeaf(num_vars=num_vars, num_vals=num_vals)
    num_sums = 32

    weights = spn.Weights(num_weights=num_vals,
                          num_sums=num_sums,
                          initializer=tf.initializers.random_uniform(),
                          log=log_weights)

    parsums = []
    for row in range(nrows):
        for col in range(ncols):
            indices = list(
                range(row * (ncols * num_vals) + col * num_vals,
                      row * (ncols * num_vals) + (col + 1) * num_vals))
            parsums.append(
                spn.ParallelSums((indicator_leaf, indices),
                                 num_sums=num_sums,
                                 weights=weights))

    convsum = spn.ConvSums(indicator_leaf,
                           num_channels=num_sums,
                           weights=weights,
                           spatial_dim_sizes=spatial_dims)

    dense_gen = spn.DenseSPNGenerator(
        num_decomps=1,
        num_mixtures=2,
        num_subsets=2,
        input_dist=spn.DenseSPNGenerator.InputDist.RAW,
        node_type=spn.DenseSPNGenerator.NodeType.BLOCK)

    rnd = random.Random(1234)
    rnd_state = rnd.getstate()
    conv_root = dense_gen.generate(convsum, rnd=rnd)
    rnd.setstate(rnd_state)

    parsum_concat = spn.Concat(*parsums, name="ParSumConcat")
    parsum_root = dense_gen.generate(parsum_concat, rnd=rnd)

    self.assertTrue(conv_root.is_valid())
    self.assertTrue(parsum_root.is_valid())
    self.assertAllEqual(parsum_concat.get_scope(), convsum.get_scope())

    spn.generate_weights(conv_root, log=log_weights)
    spn.generate_weights(parsum_root, log=log_weights)

    convsum.set_weights(weights)
    [p.set_weights(weights) for p in parsums]

    init_conv = spn.initialize_weights(conv_root)
    init_parsum = spn.initialize_weights(parsum_root)

    path_conv = spn.MPEPath(value_inference_type=inference_type)
    path_conv.get_mpe_path(conv_root)
    path_parsum = spn.MPEPath(value_inference_type=inference_type)
    path_parsum.get_mpe_path(parsum_root)

    indicator_leaf_count_parsum = path_parsum.counts[indicator_leaf]
    indicator_leaf_count_convsum = path_conv.counts[indicator_leaf]

    weight_counts_parsum = path_parsum.counts[weights]
    weight_counts_conv = path_conv.counts[weights]

    root_val_parsum = path_parsum.value.values[parsum_root]
    root_val_conv = path_conv.value.values[conv_root]

    parsum_counts = path_parsum.counts[parsum_concat]
    conv_counts = path_conv.counts[convsum]

    indicator_feed = np.random.randint(2, size=batch_size * num_vars) \
        .reshape((batch_size, num_vars))

    with tf.Session() as sess:
        sess.run([init_conv, init_parsum])
        indicator_counts_conv_out, indicator_count_parsum_out = sess.run(
            [indicator_leaf_count_convsum, indicator_leaf_count_parsum],
            feed_dict={indicator_leaf: indicator_feed})
        root_conv_value_out, root_parsum_value_out = sess.run(
            [root_val_conv, root_val_parsum],
            feed_dict={indicator_leaf: indicator_feed})
        weight_counts_conv_out, weight_counts_parsum_out = sess.run(
            [weight_counts_conv, weight_counts_parsum],
            feed_dict={indicator_leaf: indicator_feed})
        weight_value_conv_out, weight_value_parsum_out = sess.run([
            convsum.weights.node.variable, parsums[0].weights.node.variable
        ])
        parsum_counts_out, conv_counts_out = sess.run(
            [parsum_counts, conv_counts],
            feed_dict={indicator_leaf: indicator_feed})
        parsum_concat_val, convsum_val = sess.run(
            [
                path_parsum.value.values[parsum_concat],
                path_conv.value.values[convsum]
            ],
            feed_dict={indicator_leaf: indicator_feed})

        self.assertTrue(np.all(np.less_equal(convsum_val, 0.0)))
        self.assertTrue(np.all(np.less_equal(parsum_concat_val, 0.0)))
        self.assertAllClose(weight_value_conv_out, weight_value_parsum_out)
        self.assertAllClose(root_conv_value_out, root_parsum_value_out)
        self.assertAllClose(indicator_counts_conv_out, indicator_count_parsum_out)
        self.assertAllClose(parsum_counts_out, conv_counts_out)
        self.assertAllClose(weight_counts_conv_out, weight_counts_parsum_out)
def test_compare_manual_conv(self, log_weights, inference_type):
    spn.conf.argmax_zero = True
    grid_dims = [2, 2]
    nrows, ncols = grid_dims
    num_vals = 4
    batch_size = 256
    num_vars = grid_dims[0] * grid_dims[1]
    indicator_leaf = spn.IndicatorLeaf(num_vars=num_vars, num_vals=num_vals)
    num_sums = 32

    weights = spn.Weights(num_weights=num_vals,
                          num_sums=num_sums * num_vars,
                          initializer=tf.initializers.random_uniform(),
                          log=log_weights)
    weights_per_cell = tf.split(weights.variable, num_or_size_splits=num_vars)

    parsums = []
    for row in range(nrows):
        for col in range(ncols):
            indices = list(
                range(row * (ncols * num_vals) + col * num_vals,
                      row * (ncols * num_vals) + (col + 1) * num_vals))
            parsums.append(
                spn.ParallelSums((indicator_leaf, indices), num_sums=num_sums))
    parsum_concat = spn.Concat(*parsums, name="ParSumConcat")

    convsum = spn.LocalSums(indicator_leaf,
                            num_channels=num_sums,
                            weights=weights,
                            spatial_dim_sizes=grid_dims)

    prod00_conv = spn.PermuteProducts(
        (convsum, list(range(num_sums))),
        (convsum, list(range(num_sums, num_sums * 2))),
        name="Prod00")
    prod01_conv = spn.PermuteProducts(
        (convsum, list(range(num_sums * 2, num_sums * 3))),
        (convsum, list(range(num_sums * 3, num_sums * 4))),
        name="Prod01")
    sum00_conv = spn.ParallelSums(prod00_conv, num_sums=2)
    sum01_conv = spn.ParallelSums(prod01_conv, num_sums=2)
    prod10_conv = spn.PermuteProducts(sum00_conv, sum01_conv, name="Prod10")
    conv_root = spn.Sum(prod10_conv)

    prod00_pars = spn.PermuteProducts(
        (parsum_concat, list(range(num_sums))),
        (parsum_concat, list(range(num_sums, num_sums * 2))))
    prod01_pars = spn.PermuteProducts(
        (parsum_concat, list(range(num_sums * 2, num_sums * 3))),
        (parsum_concat, list(range(num_sums * 3, num_sums * 4))))
    sum00_pars = spn.ParallelSums(prod00_pars, num_sums=2)
    sum01_pars = spn.ParallelSums(prod01_pars, num_sums=2)
    prod10_pars = spn.PermuteProducts(sum00_pars, sum01_pars)
    parsum_root = spn.Sum(prod10_pars)

    node_pairs = [(sum00_conv, sum00_pars), (sum01_conv, sum01_pars),
                  (conv_root, parsum_root)]

    self.assertTrue(conv_root.is_valid())
    self.assertTrue(parsum_root.is_valid())
    self.assertAllEqual(parsum_concat.get_scope(), convsum.get_scope())

    spn.generate_weights(conv_root,
                         log=log_weights,
                         initializer=tf.initializers.random_uniform())
    spn.generate_weights(parsum_root,
                         log=log_weights,
                         initializer=tf.initializers.random_uniform())

    convsum.set_weights(weights)
    copy_weight_ops = []
    parsum_weight_nodes = []
    for p, w in zip(parsums, weights_per_cell):
        copy_weight_ops.append(tf.assign(p.weights.node.variable, w))
        parsum_weight_nodes.append(p.weights.node)
    for wc, wp in node_pairs:
        copy_weight_ops.append(
            tf.assign(wp.weights.node.variable, wc.weights.node.variable))
    copy_weights_op = tf.group(*copy_weight_ops)

    init_conv = spn.initialize_weights(conv_root)
    init_parsum = spn.initialize_weights(parsum_root)

    path_conv = spn.MPEPath(value_inference_type=inference_type)
    path_conv.get_mpe_path(conv_root)
    path_parsum = spn.MPEPath(value_inference_type=inference_type)
    path_parsum.get_mpe_path(parsum_root)

    indicator_counts_parsum = path_parsum.counts[indicator_leaf]
    indicator_counts_convsum = path_conv.counts[indicator_leaf]

    weight_counts_parsum = tf.concat(
        [path_parsum.counts[w] for w in parsum_weight_nodes], axis=1)
    weight_counts_conv = path_conv.counts[weights]

    weight_parsum_concat = tf.concat(
        [w.variable for w in parsum_weight_nodes], axis=0)

    root_val_parsum = parsum_root.get_log_value()  # path_parsum.value.values[parsum_root]
    root_val_conv = conv_root.get_log_value()  # path_conv.value.values[conv_root]

    parsum_counts = path_parsum.counts[parsum_concat]
    conv_counts = path_conv.counts[convsum]

    indicator_feed = np.random.randint(-1, 2, size=batch_size * num_vars) \
        .reshape((batch_size, num_vars))

    with tf.Session() as sess:
        sess.run([init_conv, init_parsum])
        sess.run(copy_weights_op)
        indicator_counts_conv_out, indicator_counts_parsum_out = sess.run(
            [indicator_counts_convsum, indicator_counts_parsum],
            feed_dict={indicator_leaf: indicator_feed})
        root_conv_value_out, root_parsum_value_out = sess.run(
            [root_val_conv, root_val_parsum],
            feed_dict={indicator_leaf: indicator_feed})
        weight_counts_conv_out, weight_counts_parsum_out = sess.run(
            [weight_counts_conv, weight_counts_parsum],
            feed_dict={indicator_leaf: indicator_feed})
        weight_value_conv_out, weight_value_parsum_out = sess.run(
            [convsum.weights.node.variable, weight_parsum_concat])
        parsum_counts_out, conv_counts_out = sess.run(
            [parsum_counts, conv_counts],
            feed_dict={indicator_leaf: indicator_feed})
        parsum_concat_val, convsum_val = sess.run(
            [
                path_parsum.value.values[parsum_concat],
                path_conv.value.values[convsum]
            ],
            feed_dict={indicator_leaf: indicator_feed})

        self.assertAllClose(convsum_val, parsum_concat_val)
        self.assertAllClose(weight_value_conv_out, weight_value_parsum_out)
        self.assertAllClose(root_conv_value_out, root_parsum_value_out)
        self.assertAllEqual(indicator_counts_conv_out, indicator_counts_parsum_out)
        self.assertAllEqual(parsum_counts_out, conv_counts_out)
        self.assertAllEqual(weight_counts_conv_out, weight_counts_parsum_out)
def test_compute_dense_gen_two_spatial_decomps_v2(self, node_type, input_dist):
    input_channels = 2
    grid_dims = [32, 32]
    num_vars = grid_dims[0] * grid_dims[1]
    vars = spn.IndicatorLeaf(num_vars=num_vars, num_vals=input_channels)

    convert_after = False
    if input_dist == spn.DenseSPNGenerator.InputDist.RAW and \
            node_type in [spn.DenseSPNGenerator.NodeType.BLOCK,
                          spn.DenseSPNGenerator.NodeType.LAYER]:
        node_type = spn.DenseSPNGenerator.NodeType.SINGLE
        convert_after = True

    # First decomposition
    convprod_dilate0 = spn.ConvProducts(vars,
                                        spatial_dim_sizes=grid_dims,
                                        num_channels=16,
                                        padding='valid',
                                        dilation_rate=2,
                                        strides=1,
                                        kernel_size=2)
    convprod_dilate1 = spn.ConvProducts(convprod_dilate0,
                                        spatial_dim_sizes=[30, 30],
                                        num_channels=512,
                                        padding='valid',
                                        dilation_rate=1,
                                        strides=4,
                                        kernel_size=2)
    convsum_dilate = spn.ConvSums(convprod_dilate1,
                                  num_channels=2,
                                  spatial_dim_sizes=[8, 8])

    # Second decomposition
    convprod_stride0 = spn.ConvProducts(vars,
                                        spatial_dim_sizes=grid_dims,
                                        num_channels=16,
                                        padding='valid',
                                        dilation_rate=1,
                                        strides=2,
                                        kernel_size=2)
    convprod_stride1 = spn.ConvProducts(convprod_stride0,
                                        spatial_dim_sizes=[16, 16],
                                        num_channels=512,
                                        padding='valid',
                                        dilation_rate=1,
                                        strides=2,
                                        kernel_size=2)
    convsum_stride = spn.ConvSums(convprod_stride1,
                                  num_channels=2,
                                  spatial_dim_sizes=[8, 8])

    # First decomposition, level 2
    convprod_dilate0_l2 = spn.ConvProducts(convsum_stride, convsum_dilate,
                                           spatial_dim_sizes=[8, 8],
                                           num_channels=512,
                                           padding='valid',
                                           dilation_rate=2,
                                           strides=1,
                                           kernel_size=2)
    convprod_dilate1_l2 = spn.ConvProducts(convprod_dilate0_l2,
                                           spatial_dim_sizes=[6, 6],
                                           num_channels=512,
                                           padding='valid',
                                           dilation_rate=1,
                                           kernel_size=2,
                                           strides=4)
    convsum_dilate_l2 = spn.ConvSums(convprod_dilate1_l2,
                                     num_channels=2,
                                     spatial_dim_sizes=[4, 4])

    # Second decomposition, level 2
    convprod_stride0_l2 = spn.ConvProducts(convsum_stride, convsum_dilate,
                                           spatial_dim_sizes=[8, 8],
                                           num_channels=512,
                                           padding='valid',
                                           dilation_rate=1,
                                           strides=2,
                                           kernel_size=2)
    convprod_stride1_l2 = spn.ConvProducts(convprod_stride0_l2,
                                           spatial_dim_sizes=[4, 4],
                                           num_channels=512,
                                           padding='valid',
                                           dilation_rate=1,
                                           strides=2,
                                           kernel_size=2)
    convsum_stride_l2 = spn.ConvSums(convprod_stride1_l2,
                                     num_channels=2,
                                     spatial_dim_sizes=[4, 4])

    dense_gen = spn.DenseSPNGenerator(num_mixtures=2,
                                      num_decomps=1,
                                      num_subsets=2,
                                      node_type=node_type,
                                      input_dist=input_dist)
    root = dense_gen.generate(convsum_stride_l2, convsum_dilate_l2)
    if convert_after:
        root = dense_gen.convert_to_layer_nodes(root)

    # Assert valid
    self.assertTrue(root.is_valid())

    # Setup the remaining ops
    spn.generate_weights(root)
    init = spn.initialize_weights(root)
    value_op = tf.squeeze(root.get_log_value())

    with self.test_session() as sess:
        sess.run(init)
        value_out = sess.run(value_op,
                             {vars: -np.ones((1, num_vars), dtype=np.int32)})
        self.assertAllClose(value_out, 0.0)