示例#1
0
    def test_compute_value(self):
        """ConvSums outputs, per grid cell and channel, the weighted indicator sum.

        Four binary variables on a 2x2 grid feed a ConvSums node with two
        channels whose weights are [[0.2, 0.8], [0.6, 0.4]] (one row per
        channel). Values are computed with MARGINAL inference.
        """
        leaf = spn.IndicatorLeaf(num_vals=2, num_vars=2 * 2)
        feed = [[0, 1, 1, 0], [-1, -1, -1, 0]]
        shared_weights = spn.Weights(
            initializer=tf.initializers.constant([[0.2, 0.8], [0.6, 0.4]]),
            num_sums=2,
            num_weights=2)
        conv_sums = ConvSums(leaf,
                             spatial_dim_sizes=[2, 2],
                             num_channels=2,
                             weights=shared_weights)

        value_op = conv_sums.get_value(
            inference_type=spn.InferenceType.MARGINAL)

        with self.test_session() as sess:
            sess.run(shared_weights.initialize())
            result = sess.run(value_op, {leaf: feed})

        # Output is laid out cell by cell, two channels per cell. A value of 0
        # picks the first weight of each channel (0.2, 0.6), a value of 1 the
        # second (0.8, 0.4), and -1 yields 1.0 for both channels (the two
        # weights of a channel summed).
        expected = [
            # cell values:  0 | 1 | 1 | 0
            [0.2, 0.6, 0.8, 0.4, 0.8, 0.4, 0.2, 0.6],
            # cell values: -1 | -1 | -1 | 0
            [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.2, 0.6],
        ]
        self.assertAllClose(result, expected)
示例#2
0
    def test_compute_scope(self):
        """Every ConvSums output carries the scope of the grid cell it reads.

        With two channels per cell, the scope of each of the four variables
        appears twice in the output, in variable order.
        """
        leaf = spn.IndicatorLeaf(num_vals=2, num_vars=2 * 2)
        shared_weights = spn.Weights(
            initializer=tf.initializers.constant([[0.2, 0.8], [0.6, 0.4]]),
            num_sums=2,
            num_weights=2)
        conv_sums = ConvSums(leaf,
                             spatial_dim_sizes=[2, 2],
                             num_channels=2,
                             weights=shared_weights)

        computed = conv_sums._compute_scope(None, None, leaf._compute_scope())

        expected = []
        for var_index in range(4):
            # One scope object repeated once per channel.
            expected.extend([spn.Scope(leaf, var_index)] * 2)
        self.assertAllEqual(computed, expected)
示例#3
0
    def test_single_initialization(self):
        """Single weights node initialization.

        Builds weights nodes with one or several sums from constant and random
        initializers. Checks that both ``get_value()`` and
        ``exp(get_log_value())`` return the expected weights, with every row
        summing to 1, and that both are in the configured ``spn.conf.dtype``.
        """
        uniform_4 = [0.25, 0.25, 0.25, 0.25]
        # Pairs of (weights node, expected value). Nodes are created in order.
        cases = [
            # Single sum
            (spn.Weights(tf.initializers.constant(3), num_weights=2),
             [[0.5, 0.5]]),
            (spn.Weights(tf.initializers.constant(0.3), num_weights=4),
             [uniform_4]),
            (spn.Weights(tf.initializers.constant([0.4, 0.4, 1.2]),
                         num_weights=3),
             [[0.2, 0.2, 0.6]]),
            # Multi sums
            (spn.Weights(tf.initializers.constant(3),
                         num_weights=2,
                         num_sums=2),
             [[0.5, 0.5], [0.5, 0.5]]),
            (spn.Weights(tf.initializers.constant(0.3),
                         num_weights=4,
                         num_sums=3),
             [uniform_4] * 3),
            # A single weight per sum is expected to be 1.0 whatever the
            # random initial value is.
            (spn.Weights(tf.initializers.random_uniform(0.0, 1.0),
                         num_weights=1,
                         num_sums=4),
             [[1.0], [1.0], [1.0], [1.0]]),
        ]
        weight_nodes = [node for node, _ in cases]
        init_ops = [node.initialize() for node in weight_nodes]
        with self.test_session() as sess:
            sess.run(init_ops)
            values = [sess.run(node.get_value()) for node in weight_nodes]
            log_values = [sess.run(tf.exp(node.get_log_value()))
                          for node in weight_nodes]

        dtype = spn.conf.dtype.as_numpy_dtype()
        for (_, expected), value, log_value in zip(cases, values, log_values):
            # dtype is checked for both outputs of every node, multi-sum
            # log values included.
            self.assertEqual(value.dtype, dtype)
            self.assertEqual(log_value.dtype, dtype)
            np.testing.assert_array_almost_equal(value, expected)
            np.testing.assert_array_almost_equal(log_value, expected)
示例#4
0
    def test_compare_manual_conv(self, log_weights, inference_type):
        """Compare ConvSums against an equivalent grid of ParallelSums.

        A ConvSums node over a 4x4 grid and a Concat of one ParallelSums per
        grid cell share a single weights node. The same dense SPN generator is
        applied on top of both, with its RNG reset in between. The resulting
        values, MPE-path counts and weights must then agree.

        Args:
            log_weights: Passed as ``log=`` to every weights node created here.
            inference_type: Value inference type used by both MPE paths.
        """
        # Restore the global flag when the test ends so it does not leak into
        # other tests.
        self.addCleanup(setattr, spn.conf, "argmax_zero",
                        spn.conf.argmax_zero)
        spn.conf.argmax_zero = True
        spatial_dims = [4, 4]
        nrows, ncols = spatial_dims
        num_vals = 4
        batch_size = 128
        num_vars = nrows * ncols
        indicator_leaf = spn.IndicatorLeaf(num_vars=num_vars,
                                           num_vals=num_vals)
        num_sums = 32
        weights = spn.Weights(num_weights=num_vals,
                              num_sums=num_sums,
                              initializer=tf.initializers.random_uniform(),
                              log=log_weights)

        # One ParallelSums per grid cell. Each reads the num_vals indicators
        # of its cell and shares the weights node above.
        parsums = []
        for row in range(nrows):
            for col in range(ncols):
                start = (row * ncols + col) * num_vals
                indices = list(range(start, start + num_vals))
                parsums.append(
                    spn.ParallelSums((indicator_leaf, indices),
                                     num_sums=num_sums,
                                     weights=weights))

        convsum = spn.ConvSums(indicator_leaf,
                               num_channels=num_sums,
                               weights=weights,
                               spatial_dim_sizes=spatial_dims)

        dense_gen = spn.DenseSPNGenerator(
            num_decomps=1,
            num_mixtures=2,
            num_subsets=2,
            input_dist=spn.DenseSPNGenerator.InputDist.RAW,
            node_type=spn.DenseSPNGenerator.NodeType.BLOCK)

        # Reset the RNG between the two generate() calls so both draw the
        # same random sequence.
        rnd = random.Random(1234)
        rnd_state = rnd.getstate()
        conv_root = dense_gen.generate(convsum, rnd=rnd)
        rnd.setstate(rnd_state)

        parsum_concat = spn.Concat(*parsums, name="ParSumConcat")
        parsum_root = dense_gen.generate(parsum_concat, rnd=rnd)

        self.assertTrue(conv_root.is_valid())
        self.assertTrue(parsum_root.is_valid())

        self.assertAllEqual(parsum_concat.get_scope(), convsum.get_scope())

        spn.generate_weights(conv_root, log=log_weights)
        spn.generate_weights(parsum_root, log=log_weights)

        # Make sure the compared nodes use the shared weights node after
        # generate_weights() above.
        convsum.set_weights(weights)
        for parsum in parsums:
            parsum.set_weights(weights)

        init_conv = spn.initialize_weights(conv_root)
        init_parsum = spn.initialize_weights(parsum_root)

        path_conv = spn.MPEPath(value_inference_type=inference_type)
        path_conv.get_mpe_path(conv_root)

        path_parsum = spn.MPEPath(value_inference_type=inference_type)
        path_parsum.get_mpe_path(parsum_root)

        indicator_leaf_count_parsum = path_parsum.counts[indicator_leaf]
        indicator_leaf_count_convsum = path_conv.counts[indicator_leaf]

        weight_counts_parsum = path_parsum.counts[weights]
        weight_counts_conv = path_conv.counts[weights]

        root_val_parsum = path_parsum.value.values[parsum_root]
        root_val_conv = path_conv.value.values[conv_root]

        parsum_counts = path_parsum.counts[parsum_concat]
        conv_counts = path_conv.counts[convsum]

        # Seeded so that a failing batch can be reproduced.
        # NOTE(review): only values 0 and 1 are fed although num_vals is 4;
        # confirm this is intended.
        feed_rng = np.random.RandomState(1234)
        indicator_feed = feed_rng.randint(2, size=(batch_size, num_vars))
        feed = {indicator_leaf: indicator_feed}
        with tf.Session() as sess:
            sess.run([init_conv, init_parsum])
            indicator_counts_conv_out, indicator_count_parsum_out = sess.run(
                [indicator_leaf_count_convsum, indicator_leaf_count_parsum],
                feed_dict=feed)

            root_conv_value_out, root_parsum_value_out = sess.run(
                [root_val_conv, root_val_parsum], feed_dict=feed)

            weight_counts_conv_out, weight_counts_parsum_out = sess.run(
                [weight_counts_conv, weight_counts_parsum], feed_dict=feed)

            weight_value_conv_out, weight_value_parsum_out = sess.run([
                convsum.weights.node.variable, parsums[0].weights.node.variable
            ])

            parsum_counts_out, conv_counts_out = sess.run(
                [parsum_counts, conv_counts], feed_dict=feed)

            parsum_concat_val, convsum_val = sess.run(
                [
                    path_parsum.value.values[parsum_concat],
                    path_conv.value.values[convsum]
                ],
                feed_dict=feed)

        # The compared node values are checked to be <= 0.
        self.assertTrue(np.all(np.less_equal(convsum_val, 0.0)))
        self.assertTrue(np.all(np.less_equal(parsum_concat_val, 0.0)))
        self.assertAllClose(weight_value_conv_out, weight_value_parsum_out)
        self.assertAllClose(root_conv_value_out, root_parsum_value_out)
        self.assertAllClose(indicator_counts_conv_out,
                            indicator_count_parsum_out)
        self.assertAllClose(parsum_counts_out, conv_counts_out)
        self.assertAllClose(weight_counts_conv_out, weight_counts_parsum_out)
示例#5
0
    def test_compare_manual_conv(self, log_weights, inference_type):
        """Compare LocalSums against a grid of ParallelSums with copied weights.

        A LocalSums node over a 2x2 grid is compared with a Concat of one
        ParallelSums per grid cell. The LocalSums weights have
        ``num_sums * num_vars`` rows, split into one block per cell. The same
        product/sum structure is built by hand on top of both. All weights on
        the ParallelSums side are copied from the LocalSums side. The
        resulting values, MPE-path counts and weights must then agree.

        Args:
            log_weights: Passed as ``log=`` to every weights node created here.
            inference_type: Value inference type used by both MPE paths.
        """
        # Restore the global flag when the test ends so it does not leak into
        # other tests.
        self.addCleanup(setattr, spn.conf, "argmax_zero",
                        spn.conf.argmax_zero)
        spn.conf.argmax_zero = True
        grid_dims = [2, 2]
        nrows, ncols = grid_dims
        num_vals = 4
        batch_size = 256
        num_vars = nrows * ncols
        indicator_leaf = spn.IndicatorLeaf(num_vars=num_vars,
                                           num_vals=num_vals)
        num_sums = 32
        weights = spn.Weights(num_weights=num_vals,
                              num_sums=num_sums * num_vars,
                              initializer=tf.initializers.random_uniform(),
                              log=log_weights)

        # One block of num_sums weight rows per grid cell.
        weights_per_cell = tf.split(weights.variable,
                                    num_or_size_splits=num_vars)

        # One ParallelSums per grid cell, reading the num_vals indicators of
        # that cell.
        parsums = []
        for row in range(nrows):
            for col in range(ncols):
                start = (row * ncols + col) * num_vals
                indices = list(range(start, start + num_vals))
                parsums.append(
                    spn.ParallelSums((indicator_leaf, indices),
                                     num_sums=num_sums))

        parsum_concat = spn.Concat(*parsums, name="ParSumConcat")
        convsum = spn.LocalSums(indicator_leaf,
                                num_channels=num_sums,
                                weights=weights,
                                spatial_dim_sizes=grid_dims)

        prod00_conv = spn.PermuteProducts(
            (convsum, list(range(num_sums))),
            (convsum, list(range(num_sums, num_sums * 2))),
            name="Prod00")
        prod01_conv = spn.PermuteProducts(
            (convsum, list(range(num_sums * 2, num_sums * 3))),
            (convsum, list(range(num_sums * 3, num_sums * 4))),
            name="Prod01")
        sum00_conv = spn.ParallelSums(prod00_conv, num_sums=2)
        sum01_conv = spn.ParallelSums(prod01_conv, num_sums=2)

        prod10_conv = spn.PermuteProducts(sum00_conv,
                                          sum01_conv,
                                          name="Prod10")

        conv_root = spn.Sum(prod10_conv)

        prod00_pars = spn.PermuteProducts(
            (parsum_concat, list(range(num_sums))),
            (parsum_concat, list(range(num_sums, num_sums * 2))))
        prod01_pars = spn.PermuteProducts(
            (parsum_concat, list(range(num_sums * 2, num_sums * 3))),
            (parsum_concat, list(range(num_sums * 3, num_sums * 4))))

        sum00_pars = spn.ParallelSums(prod00_pars, num_sums=2)
        sum01_pars = spn.ParallelSums(prod01_pars, num_sums=2)

        prod10_pars = spn.PermuteProducts(sum00_pars, sum01_pars)

        parsum_root = spn.Sum(prod10_pars)

        # (LocalSums-side node, ParallelSums-side node) whose weights are
        # copied from the first to the second.
        node_pairs = [(sum00_conv, sum00_pars), (sum01_conv, sum01_pars),
                      (conv_root, parsum_root)]

        self.assertTrue(conv_root.is_valid())
        self.assertTrue(parsum_root.is_valid())

        self.assertAllEqual(parsum_concat.get_scope(), convsum.get_scope())

        spn.generate_weights(conv_root,
                             log=log_weights,
                             initializer=tf.initializers.random_uniform())
        spn.generate_weights(parsum_root,
                             log=log_weights,
                             initializer=tf.initializers.random_uniform())

        convsum.set_weights(weights)
        copy_weight_ops = []
        parsum_weight_nodes = []
        for parsum, cell_weights in zip(parsums, weights_per_cell):
            copy_weight_ops.append(
                tf.assign(parsum.weights.node.variable, cell_weights))
            parsum_weight_nodes.append(parsum.weights.node)

        for node_conv, node_pars in node_pairs:
            copy_weight_ops.append(
                tf.assign(node_pars.weights.node.variable,
                          node_conv.weights.node.variable))

        copy_weights_op = tf.group(*copy_weight_ops)

        init_conv = spn.initialize_weights(conv_root)
        init_parsum = spn.initialize_weights(parsum_root)

        path_conv = spn.MPEPath(value_inference_type=inference_type)
        path_conv.get_mpe_path(conv_root)

        path_parsum = spn.MPEPath(value_inference_type=inference_type)
        path_parsum.get_mpe_path(parsum_root)

        indicator_counts_parsum = path_parsum.counts[indicator_leaf]
        indicator_counts_convsum = path_conv.counts[indicator_leaf]

        weight_counts_parsum = tf.concat(
            [path_parsum.counts[w] for w in parsum_weight_nodes], axis=1)
        weight_counts_conv = path_conv.counts[weights]

        weight_parsum_concat = tf.concat(
            [w.variable for w in parsum_weight_nodes], axis=0)

        # Root values come from get_log_value() with its default inference
        # type, not from the MPE paths' value computations.
        root_val_parsum = parsum_root.get_log_value()
        root_val_conv = conv_root.get_log_value()

        parsum_counts = path_parsum.counts[parsum_concat]
        conv_counts = path_conv.counts[convsum]

        # Seeded so that a failing batch can be reproduced. Values are drawn
        # from {-1, 0, 1}.
        feed_rng = np.random.RandomState(1234)
        indicator_feed = feed_rng.randint(-1, 2, size=(batch_size, num_vars))
        feed = {indicator_leaf: indicator_feed}
        with tf.Session() as sess:
            sess.run([init_conv, init_parsum])
            sess.run(copy_weights_op)
            indicator_counts_conv_out, indicator_counts_parsum_out = sess.run(
                [indicator_counts_convsum, indicator_counts_parsum],
                feed_dict=feed)

            root_conv_value_out, root_parsum_value_out = sess.run(
                [root_val_conv, root_val_parsum], feed_dict=feed)

            weight_counts_conv_out, weight_counts_parsum_out = sess.run(
                [weight_counts_conv, weight_counts_parsum], feed_dict=feed)

            weight_value_conv_out, weight_value_parsum_out = sess.run(
                [convsum.weights.node.variable, weight_parsum_concat])

            parsum_counts_out, conv_counts_out = sess.run(
                [parsum_counts, conv_counts], feed_dict=feed)

            parsum_concat_val, convsum_val = sess.run(
                [
                    path_parsum.value.values[parsum_concat],
                    path_conv.value.values[convsum]
                ],
                feed_dict=feed)

        self.assertAllClose(convsum_val, parsum_concat_val)
        self.assertAllClose(weight_value_conv_out, weight_value_parsum_out)
        self.assertAllClose(root_conv_value_out, root_parsum_value_out)
        self.assertAllEqual(indicator_counts_conv_out,
                            indicator_counts_parsum_out)
        self.assertAllEqual(parsum_counts_out, conv_counts_out)
        self.assertAllEqual(weight_counts_conv_out, weight_counts_parsum_out)