def test_compute_value(self):
    """Check the marginal value of a 2x2 ``ConvSums`` layer with two channels.

    The layer sits directly on an indicator leaf with four binary variables
    and uses one shared 2x2 weight matrix. Each variable contributes two
    consecutive outputs, one per channel.
    """
    leaf = spn.IndicatorLeaf(num_vals=2, num_vars=2 * 2)
    shared_weights = spn.Weights(
        initializer=tf.initializers.constant([[0.2, 0.8], [0.6, 0.4]]),
        num_sums=2,
        num_weights=2)
    conv_sums = ConvSums(
        leaf,
        spatial_dim_sizes=[2, 2],
        num_channels=2,
        weights=shared_weights)
    marginal_value = conv_sums.get_value(
        inference_type=spn.InferenceType.MARGINAL)

    # One row per sample, one column per variable.
    feed = [[0, 1, 1, 0], [-1, -1, -1, 0]]
    expected = [
        # Value 0 picks the first weight column (0.2, 0.6) and value 1 the
        # second one (0.8, 0.4).
        [0.2, 0.6, 0.8, 0.4, 0.8, 0.4, 0.2, 0.6],
        # Variables fed with -1 are expected to yield 1.0 on both channels.
        [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.2, 0.6],
    ]

    with self.test_session() as sess:
        sess.run(shared_weights.initialize())
        result = sess.run(marginal_value, {leaf: feed})
        self.assertAllClose(result, expected)
def test_compute_scope(self):
    """Check the scope computed by a 2x2 ``ConvSums`` layer with two channels.

    The expected result lists the scope of each of the four leaf variables
    twice in a row, i.e. once per channel.
    """
    leaf = spn.IndicatorLeaf(num_vals=2, num_vars=2 * 2)
    shared_weights = spn.Weights(
        initializer=tf.initializers.constant([[0.2, 0.8], [0.6, 0.4]]),
        num_sums=2,
        num_weights=2)
    conv_sums = ConvSums(
        leaf,
        spatial_dim_sizes=[2, 2],
        num_channels=2,
        weights=shared_weights)

    actual = conv_sums._compute_scope(None, None, leaf._compute_scope())

    expected = []
    for var_index in range(4):
        # Both channels of one variable share a single scope object.
        expected.extend([spn.Scope(leaf, var_index)] * 2)

    self.assertAllEqual(actual, expected)
def test_single_initialization(self):
    """Single weights node initialization.

    Creates six ``spn.Weights`` nodes (three with a single sum, three with
    several sums), initializes them, and checks for every node that both
    ``get_value()`` and ``exp(get_log_value())`` have the configured dtype
    and match the expected values listed below.
    """
    # (label, weights node, expected value) -- the label only shows up in
    # failure messages.
    cases = [
        # Single sum
        ("w1",
         spn.Weights(tf.initializers.constant(3), num_weights=2),
         [[0.5, 0.5]]),
        ("w2",
         spn.Weights(tf.initializers.constant(0.3), num_weights=4),
         [[0.25, 0.25, 0.25, 0.25]]),
        ("w3",
         spn.Weights(tf.initializers.constant([0.4, 0.4, 1.2]),
                     num_weights=3),
         [[0.2, 0.2, 0.6]]),
        # Multi sums
        ("w4",
         spn.Weights(tf.initializers.constant(3), num_weights=2, num_sums=2),
         [[0.5, 0.5], [0.5, 0.5]]),
        ("w5",
         spn.Weights(tf.initializers.constant(0.3), num_weights=4,
                     num_sums=3),
         [[0.25, 0.25, 0.25, 0.25],
          [0.25, 0.25, 0.25, 0.25],
          [0.25, 0.25, 0.25, 0.25]]),
        # A random initializer is fine here: with a single weight per sum
        # the expected value is 1.0 whatever was drawn.
        ("w6",
         spn.Weights(tf.initializers.random_uniform(0.0, 1.0),
                     num_weights=1, num_sums=4),
         [[1.0], [1.0], [1.0], [1.0]]),
    ]
    nodes = [node for _, node, _ in cases]
    init_ops = [node.initialize() for node in nodes]

    with self.test_session() as sess:
        sess.run(init_ops)
        values = [sess.run(node.get_value()) for node in nodes]
        exp_log_values = [
            sess.run(tf.exp(node.get_log_value())) for node in nodes
        ]

    expected_dtype = spn.conf.dtype.as_numpy_dtype()
    for (label, _, expected), value, exp_log_value in zip(
            cases, values, exp_log_values):
        self.assertEqual(value.dtype, expected_dtype)
        np.testing.assert_array_almost_equal(
            value, expected, err_msg="value of %s" % label)
        # The log-path dtype is checked for every case, not just the
        # single-sum ones.
        self.assertEqual(exp_log_value.dtype, expected_dtype)
        np.testing.assert_array_almost_equal(
            exp_log_value, expected, err_msg="exp(log value) of %s" % label)
def test_compare_manual_conv(self, log_weights, inference_type):
    """Compare ``spn.ConvSums`` with a hand-built stack of ``spn.ParallelSums``.

    Two SPNs are built on the same indicator leaf and the same weights node:
    one starts with a single ``ConvSums`` layer, the other with one
    ``ParallelSums`` node per grid cell, concatenated. The same dense
    generator, replayed from the same random state, builds the rest of both
    networks. The test then checks that root values, MPE path counts and
    weight values agree between the two.

    Args:
        log_weights: Passed as ``log=`` to ``spn.Weights`` and
            ``spn.generate_weights``.
        inference_type: Passed as ``value_inference_type`` to
            ``spn.MPEPath``.
    """
    # ``spn.conf`` is shared module state: put the previous setting back
    # when the test ends so the override does not leak into other tests.
    self.addCleanup(setattr, spn.conf, "argmax_zero", spn.conf.argmax_zero)
    spn.conf.argmax_zero = True

    spatial_dims = [4, 4]
    nrows, ncols = spatial_dims
    num_vals = 4
    batch_size = 128
    num_vars = spatial_dims[0] * spatial_dims[1]
    indicator_leaf = spn.IndicatorLeaf(num_vars=num_vars, num_vals=num_vals)
    num_sums = 32
    weights = spn.Weights(
        num_weights=num_vals,
        num_sums=num_sums,
        initializer=tf.initializers.random_uniform(),
        log=log_weights)

    # Manual counterpart: one ParallelSums per grid cell, each connected to
    # that cell's ``num_vals`` consecutive indicators.
    parsums = []
    for row in range(nrows):
        for col in range(ncols):
            first = (row * ncols + col) * num_vals
            indices = list(range(first, first + num_vals))
            parsums.append(
                spn.ParallelSums(
                    (indicator_leaf, indices),
                    num_sums=num_sums,
                    weights=weights))

    convsum = spn.ConvSums(
        indicator_leaf,
        num_channels=num_sums,
        weights=weights,
        spatial_dim_sizes=spatial_dims)

    dense_gen = spn.DenseSPNGenerator(
        num_decomps=1,
        num_mixtures=2,
        num_subsets=2,
        input_dist=spn.DenseSPNGenerator.InputDist.RAW,
        node_type=spn.DenseSPNGenerator.NodeType.BLOCK)

    # Replay the same random state for both generate() calls.
    rnd = random.Random(1234)
    rnd_state = rnd.getstate()
    conv_root = dense_gen.generate(convsum, rnd=rnd)
    rnd.setstate(rnd_state)

    parsum_concat = spn.Concat(*parsums, name="ParSumConcat")
    parsum_root = dense_gen.generate(parsum_concat, rnd=rnd)

    self.assertTrue(conv_root.is_valid())
    self.assertTrue(parsum_root.is_valid())
    self.assertAllEqual(parsum_concat.get_scope(), convsum.get_scope())

    spn.generate_weights(conv_root, log=log_weights)
    spn.generate_weights(parsum_root, log=log_weights)

    # Re-attach the shared weights node after generate_weights().
    convsum.set_weights(weights)
    for parsum in parsums:
        parsum.set_weights(weights)

    init_conv = spn.initialize_weights(conv_root)
    init_parsum = spn.initialize_weights(parsum_root)

    path_conv = spn.MPEPath(value_inference_type=inference_type)
    path_conv.get_mpe_path(conv_root)

    path_parsum = spn.MPEPath(value_inference_type=inference_type)
    path_parsum.get_mpe_path(parsum_root)

    indicator_leaf_count_parsum = path_parsum.counts[indicator_leaf]
    indicator_leaf_count_convsum = path_conv.counts[indicator_leaf]

    weight_counts_parsum = path_parsum.counts[weights]
    weight_counts_conv = path_conv.counts[weights]

    root_val_parsum = path_parsum.value.values[parsum_root]
    root_val_conv = path_conv.value.values[conv_root]

    parsum_counts = path_parsum.counts[parsum_concat]
    conv_counts = path_conv.counts[convsum]

    indicator_feed = np.random.randint(
        2, size=batch_size * num_vars).reshape((batch_size, num_vars))
    feed_dict = {indicator_leaf: indicator_feed}

    with tf.Session() as sess:
        sess.run([init_conv, init_parsum])
        indicator_counts_conv_out, indicator_count_parsum_out = sess.run(
            [indicator_leaf_count_convsum, indicator_leaf_count_parsum],
            feed_dict=feed_dict)

        root_conv_value_out, root_parsum_value_out = sess.run(
            [root_val_conv, root_val_parsum], feed_dict=feed_dict)

        weight_counts_conv_out, weight_counts_parsum_out = sess.run(
            [weight_counts_conv, weight_counts_parsum], feed_dict=feed_dict)

        weight_value_conv_out, weight_value_parsum_out = sess.run([
            convsum.weights.node.variable,
            parsums[0].weights.node.variable,
        ])

        parsum_counts_out, conv_counts_out = sess.run(
            [parsum_counts, conv_counts], feed_dict=feed_dict)

        parsum_concat_val, convsum_val = sess.run(
            [
                path_parsum.value.values[parsum_concat],
                path_conv.value.values[convsum],
            ],
            feed_dict=feed_dict)

    # The path values of both first layers are expected to be non-positive.
    self.assertTrue(np.all(np.less_equal(convsum_val, 0.0)))
    self.assertTrue(np.all(np.less_equal(parsum_concat_val, 0.0)))
    self.assertAllClose(weight_value_conv_out, weight_value_parsum_out)
    self.assertAllClose(root_conv_value_out, root_parsum_value_out)
    self.assertAllClose(indicator_counts_conv_out, indicator_count_parsum_out)
    self.assertAllClose(parsum_counts_out, conv_counts_out)
    self.assertAllClose(weight_counts_conv_out, weight_counts_parsum_out)
def test_compare_manual_conv(self, log_weights, inference_type):
    """Compare ``spn.LocalSums`` with a hand-built stack of ``spn.ParallelSums``.

    Two structurally parallel SPNs are built on the same indicator leaf. The
    first starts with one ``LocalSums`` layer backed by a single weights
    node holding ``num_sums * num_vars`` sums. The second starts with one
    ``ParallelSums`` node per grid cell, each with its own generated
    weights. After initialization the weights of the first network are
    copied into the second (per cell for the first layer, node by node for
    the upper sums). The test then checks that layer values, root
    log-values, weight values and MPE path counts agree.

    Args:
        log_weights: Passed as ``log=`` to ``spn.Weights`` and
            ``spn.generate_weights``.
        inference_type: Passed as ``value_inference_type`` to
            ``spn.MPEPath``.
    """
    # ``spn.conf`` is shared module state: put the previous setting back
    # when the test ends so the override does not leak into other tests.
    self.addCleanup(setattr, spn.conf, "argmax_zero", spn.conf.argmax_zero)
    spn.conf.argmax_zero = True

    grid_dims = [2, 2]
    nrows, ncols = grid_dims
    num_vals = 4
    batch_size = 256
    num_vars = grid_dims[0] * grid_dims[1]
    indicator_leaf = spn.IndicatorLeaf(num_vars=num_vars, num_vals=num_vals)
    num_sums = 32
    weights = spn.Weights(
        num_weights=num_vals,
        num_sums=num_sums * num_vars,
        initializer=tf.initializers.random_uniform(),
        log=log_weights)

    # One slice of the big weight variable per grid cell; these are later
    # assigned to the per-cell ParallelSums weights.
    weights_per_cell = tf.split(weights.variable, num_or_size_splits=num_vars)

    parsums = []
    for row in range(nrows):
        for col in range(ncols):
            first = (row * ncols + col) * num_vals
            indices = list(range(first, first + num_vals))
            parsums.append(
                spn.ParallelSums((indicator_leaf, indices), num_sums=num_sums))

    parsum_concat = spn.Concat(*parsums, name="ParSumConcat")
    convsum = spn.LocalSums(
        indicator_leaf,
        num_channels=num_sums,
        weights=weights,
        spatial_dim_sizes=grid_dims)

    # Upper part of the LocalSums-based network.
    prod00_conv = spn.PermuteProducts(
        (convsum, list(range(num_sums))),
        (convsum, list(range(num_sums, num_sums * 2))),
        name="Prod00")
    prod01_conv = spn.PermuteProducts(
        (convsum, list(range(num_sums * 2, num_sums * 3))),
        (convsum, list(range(num_sums * 3, num_sums * 4))),
        name="Prod01")
    sum00_conv = spn.ParallelSums(prod00_conv, num_sums=2)
    sum01_conv = spn.ParallelSums(prod01_conv, num_sums=2)

    prod10_conv = spn.PermuteProducts(sum00_conv, sum01_conv, name="Prod10")

    conv_root = spn.Sum(prod10_conv)

    # Same upper part on top of the manually concatenated ParallelSums.
    prod00_pars = spn.PermuteProducts(
        (parsum_concat, list(range(num_sums))),
        (parsum_concat, list(range(num_sums, num_sums * 2))))
    prod01_pars = spn.PermuteProducts(
        (parsum_concat, list(range(num_sums * 2, num_sums * 3))),
        (parsum_concat, list(range(num_sums * 3, num_sums * 4))))

    sum00_pars = spn.ParallelSums(prod00_pars, num_sums=2)
    sum01_pars = spn.ParallelSums(prod01_pars, num_sums=2)

    prod10_pars = spn.PermuteProducts(sum00_pars, sum01_pars)

    parsum_root = spn.Sum(prod10_pars)

    # Matching sum nodes of both networks whose weights must be synchronized.
    node_pairs = [
        (sum00_conv, sum00_pars),
        (sum01_conv, sum01_pars),
        (conv_root, parsum_root),
    ]

    self.assertTrue(conv_root.is_valid())
    self.assertTrue(parsum_root.is_valid())

    self.assertAllEqual(parsum_concat.get_scope(), convsum.get_scope())

    spn.generate_weights(
        conv_root, log=log_weights,
        initializer=tf.initializers.random_uniform())
    spn.generate_weights(
        parsum_root, log=log_weights,
        initializer=tf.initializers.random_uniform())

    # Re-attach the explicitly created weights node after generate_weights().
    convsum.set_weights(weights)

    # Ops that copy the first network's weights into the second one.
    copy_weight_ops = []
    parsum_weight_nodes = []
    for parsum, cell_weights in zip(parsums, weights_per_cell):
        copy_weight_ops.append(
            tf.assign(parsum.weights.node.variable, cell_weights))
        parsum_weight_nodes.append(parsum.weights.node)

    for conv_node, parsum_node in node_pairs:
        copy_weight_ops.append(
            tf.assign(parsum_node.weights.node.variable,
                      conv_node.weights.node.variable))

    copy_weights_op = tf.group(*copy_weight_ops)

    init_conv = spn.initialize_weights(conv_root)
    init_parsum = spn.initialize_weights(parsum_root)

    path_conv = spn.MPEPath(value_inference_type=inference_type)
    path_conv.get_mpe_path(conv_root)

    path_parsum = spn.MPEPath(value_inference_type=inference_type)
    path_parsum.get_mpe_path(parsum_root)

    indicator_counts_parsum = path_parsum.counts[indicator_leaf]
    indicator_counts_convsum = path_conv.counts[indicator_leaf]

    # Per-cell results are concatenated so they can be compared with the
    # single weights node of the LocalSums layer.
    weight_counts_parsum = tf.concat(
        [path_parsum.counts[w] for w in parsum_weight_nodes], axis=1)
    weight_counts_conv = path_conv.counts[weights]

    weight_parsum_concat = tf.concat(
        [w.variable for w in parsum_weight_nodes], axis=0)

    root_val_parsum = parsum_root.get_log_value()
    root_val_conv = conv_root.get_log_value()

    parsum_counts = path_parsum.counts[parsum_concat]
    conv_counts = path_conv.counts[convsum]

    # Values are drawn from {-1, 0, 1}.
    indicator_feed = np.random.randint(
        -1, 2, size=batch_size * num_vars).reshape((batch_size, num_vars))
    feed_dict = {indicator_leaf: indicator_feed}

    with tf.Session() as sess:
        sess.run([init_conv, init_parsum])
        # Must run after initialization so the copies are not overwritten.
        sess.run(copy_weights_op)
        indicator_counts_conv_out, indicator_counts_parsum_out = sess.run(
            [indicator_counts_convsum, indicator_counts_parsum],
            feed_dict=feed_dict)

        root_conv_value_out, root_parsum_value_out = sess.run(
            [root_val_conv, root_val_parsum], feed_dict=feed_dict)

        weight_counts_conv_out, weight_counts_parsum_out = sess.run(
            [weight_counts_conv, weight_counts_parsum], feed_dict=feed_dict)

        weight_value_conv_out, weight_value_parsum_out = sess.run(
            [convsum.weights.node.variable, weight_parsum_concat])

        parsum_counts_out, conv_counts_out = sess.run(
            [parsum_counts, conv_counts], feed_dict=feed_dict)

        parsum_concat_val, convsum_val = sess.run(
            [
                path_parsum.value.values[parsum_concat],
                path_conv.value.values[convsum],
            ],
            feed_dict=feed_dict)

    self.assertAllClose(convsum_val, parsum_concat_val)
    self.assertAllClose(weight_value_conv_out, weight_value_parsum_out)
    self.assertAllClose(root_conv_value_out, root_parsum_value_out)
    self.assertAllEqual(indicator_counts_conv_out, indicator_counts_parsum_out)
    self.assertAllEqual(parsum_counts_out, conv_counts_out)
    self.assertAllEqual(weight_counts_conv_out, weight_counts_parsum_out)