def test_traverse_graph_nostop_noparams(self):
        """Traversing the whole graph excluding param nodes"""
        # Collect visited nodes in traversal order
        visited = []

        # Generate graph
        v1 = spn.RawLeaf(num_vars=1)
        v2 = spn.RawLeaf(num_vars=1)
        v3 = spn.RawLeaf(num_vars=1)
        s1 = spn.Sum(v1, v1, v2)  # v1 included twice
        s2 = spn.Sum(v1, v3)
        s3 = spn.Sum(v2, v3, v3)  # v3 included twice
        s4 = spn.Sum(s1, v1)
        s5 = spn.Sum(s2, v3, s3)
        s6 = spn.Sum(s4, s2, s5, s4, s5)  # s4 and s5 included twice
        spn.generate_weights(s6)

        # Traverse, skipping parameter (weight) nodes
        spn.traverse_graph(s6, fun=visited.append, skip_params=True)

        # Each node must be visited exactly once, in BFS-like order
        self.assertEqual(len(visited), 9)
        expected_order = [s6, s4, s2, s5, s1, v1, v3, s3, v2]
        for actual, expected in zip(visited, expected_order):
            self.assertIs(actual, expected)
# Esempio n. 2
# 0
    def poons_multi(inputs,
                    num_vals,
                    num_mixtures,
                    num_subsets,
                    inf_type,
                    log=False,
                    output=None):
        """Build a Poon-architecture SPN from multi-op nodes; return MPE path ops."""

        # One ParSums node (num_mixtures parallel sums) per variable subset
        subsets = []
        for i in range(num_subsets):
            var_slice = list(range(i * num_vals, (i + 1) * num_vals))
            subsets.append(spn.ParSums((inputs, var_slice),
                                       num_sums=num_mixtures))
        products = spn.PermProducts(*subsets)
        root = spn.Sum(products, name="root")

        # Generate dense SPN and all weights in the network
        spn.generate_weights(root)

        # MPE-path generator in log or linear space, depending on `log`
        mpe_path_gen = spn.MPEPath(value_inference_type=inf_type,
                                   log=bool(log))
        mpe_path_gen.get_mpe_path(root)

        # One counts op per input node
        input_nodes = inputs if isinstance(inputs, list) else [inputs]
        path_ops = [mpe_path_gen.counts[inp] for inp in input_nodes]
        return root, spn.initialize_weights(root), path_ops
    def test_compute_graph_up_noconst(self):
        """Computing value assuming no constant functions"""
        # Number of times val_fun was called; a one-element list lets the
        # nested function mutate it without a nonlocal declaration.
        calls = [0]

        def val_fun(node, *inputs):
            calls[0] += 1
            if isinstance(node, spn.graph.node.VarNode):
                return 1
            if isinstance(node, spn.graph.node.ParamNode):
                return 0.1
            # Op node: inputs are (weight value, latent value, *child values)
            weight_val, iv_val, *values = inputs
            return weight_val + sum(values) + 1

        # Generate graph
        v1 = spn.RawLeaf(num_vars=1)
        v2 = spn.RawLeaf(num_vars=1)
        v3 = spn.RawLeaf(num_vars=1)
        s1 = spn.Sum(v1, v1, v2)  # v1 included twice
        s2 = spn.Sum(v1, v3)
        s3 = spn.Sum(v2, v3, v3)  # v3 included twice
        s4 = spn.Sum(s1, v1)
        s5 = spn.Sum(s2, v3, s3)
        s6 = spn.Sum(s4, s2, s5, s4, s5)  # s4 and s5 included twice
        spn.generate_weights(s6)

        # Calculate value bottom-up
        val = spn.compute_graph_up(s6, val_fun)

        # Each of the 15 nodes (including weights) is evaluated exactly once
        self.assertAlmostEqual(val, 35.2)
        self.assertEqual(calls[0], 15)
    def test_gradient_on_dense_spn(self, num_decomps, num_subsets,
                                   num_mixtures, input_dist, num_vars,
                                   num_components, softplus):
        """Check the Gaussian leaf's custom gradients against tf.gradients.

        Builds a dense LAYER-node SPN over a GaussianLeaf, computes the
        loc/scale gradients via the library's Gradient machinery, and
        compares them with TensorFlow autodiff on the same log value.
        """
        batch_size = 9

        # Deterministic, distinct initial means: one row per variable,
        # one column per mixture component.
        mean_init = np.arange(num_vars * num_components).reshape(
            num_vars, num_components)
        gl = spn.GaussianLeaf(num_vars=num_vars,
                              num_components=num_components,
                              loc_init=mean_init,
                              softplus_scale=softplus)

        gen = spn.DenseSPNGenerator(
            num_decomps=num_decomps,
            num_subsets=num_subsets,
            num_mixtures=num_mixtures,
            node_type=spn.DenseSPNGenerator.NodeType.LAYER,
            input_dist=input_dist)

        root = gen.generate(gl, root_name="root")

        with tf.name_scope("Weights"):
            spn.generate_weights(root,
                                 tf.initializers.random_uniform(0.0, 1.0),
                                 log=True)

        init = spn.initialize_weights(root)

        self.assertTrue(root.is_valid())

        log_val = root.get_log_value()

        # Library-side gradient computation, in log space
        spn_grad = spn.Gradient(log=True)

        spn_grad.get_gradients(root)

        # Custom gradients for the Gaussian leaf's loc/scale parameters
        mean_grad_custom, var_grad_custom = gl._compute_gradient(
            spn_grad.gradients[gl])

        # Reference gradients computed by TensorFlow autodiff
        mean_grad_tf, var_grad_tf = tf.gradients(
            log_val, [gl.loc_variable, gl.scale_variable])

        fd = {gl: np.random.rand(batch_size, num_vars)}

        with self.test_session() as sess:
            sess.run(init)
            mu_grad_custom_val, var_grad_custom_val = sess.run(
                [mean_grad_custom, var_grad_custom], fd)

            mu_grad_tf_val, var_grad_tf_val = sess.run(
                [mean_grad_tf, var_grad_tf], fd)

        # Custom and autodiff gradients must agree within tolerance
        self.assertAllClose(mu_grad_custom_val,
                            mu_grad_tf_val,
                            atol=1e-4,
                            rtol=1e-4)
        self.assertAllClose(var_grad_custom_val,
                            var_grad_tf_val,
                            atol=1e-4,
                            rtol=1e-4)
# Esempio n. 5
# 0
    def mnist_01(inputs, num_decomps, num_subsets, num_mixtures, num_input_mixtures,
                 balanced, input_dist, node_type, inf_type, log=False):
        """Build a two-root MNIST (digits 0/1) SPN with hard-EM learning ops.

        Returns:
            Tuple of (root, latent, learning, additive_smoothing,
            min_additive_smoothing, additive_smoothing_var).
        """

        # Learning Parameters
        additive_smoothing = 100
        min_additive_smoothing = 1

        # Weight initialization
        weight_init_value = tf.initializers.random_uniform(10, 11)

        # Generate SPN structure.
        # FIX: compare the input_dist selector with `==` instead of `is`;
        # identity comparison against a string literal is implementation
        # dependent and raises a SyntaxWarning on Python 3.8+.
        dense_gen = spn.DenseSPNGenerator(
            num_decomps=num_decomps,
            num_subsets=num_subsets,
            num_mixtures=num_mixtures,
            input_dist=(spn.DenseSPNGenerator.InputDist.RAW
                        if input_dist == "RAW" else
                        spn.DenseSPNGenerator.InputDist.MIXTURE),
            num_input_mixtures=num_input_mixtures,
            balanced=balanced,
            node_type=node_type)
        root0 = dense_gen.generate(inputs, root_name="root_0")
        root1 = dense_gen.generate(inputs, root_name="root_1")
        root = spn.Sum(root0, root1, name="root")
        spn.generate_weights(root, initializer=weight_init_value)
        latent = root.generate_latent_indicators()

        # Add EM Learning
        additive_smoothing_var = tf.Variable(additive_smoothing, dtype=spn.conf.dtype)
        learning = spn.HardEMLearning(root, log=log, value_inference_type=inf_type,
                                      additive_smoothing=additive_smoothing_var)

        return root, latent, learning, additive_smoothing, min_additive_smoothing, \
            additive_smoothing_var
# Esempio n. 6
# 0
    def test_get_num_nodes(self):
        """Computing the number of nodes in the SPN graph"""
        # Generate graph
        v1 = spn.ContVars(num_vars=1)
        v2 = spn.ContVars(num_vars=1)
        v3 = spn.ContVars(num_vars=1)
        s1 = spn.Sum(v1, v1, v2)  # v1 included twice
        s2 = spn.Sum(v1, v3)
        s3 = spn.Sum(v2, v3, v3)  # v3 included twice
        s4 = spn.Sum(s1, v1)
        s5 = spn.Sum(s2, v3, s3)
        s6 = spn.Sum(s4, s2, s5, s4, s5)  # s4 and s5 included twice
        spn.generate_weights(s6)

        # (node, expected count without params, expected count with params)
        expected_counts = [
            (v1, 1, 1),
            (v2, 1, 1),
            (v3, 1, 1),
            (s1, 3, 4),
            (s2, 3, 4),
            (s3, 3, 4),
            (s4, 4, 6),
            (s5, 6, 9),
            (s6, 9, 15),
        ]
        for node, without_params, with_params in expected_counts:
            self.assertEqual(node.get_num_nodes(skip_params=True),
                             without_params)
            self.assertEqual(node.get_num_nodes(skip_params=False),
                             with_params)
# Esempio n. 7
# 0
    def mnist_all(inputs,
                  num_decomps,
                  num_subsets,
                  num_mixtures,
                  num_input_mixtures,
                  balanced,
                  input_dist,
                  node_type,
                  inf_type,
                  log=False):
        """Build a 10-class MNIST SPN (one sub-SPN per digit) with EM learning ops.

        Returns:
            Tuple of (root, latent, learning, additive_smoothing,
            min_additive_smoothing, additive_smoothing_var).
        """

        # Learning Parameters
        additive_smoothing = 0
        min_additive_smoothing = 0
        initial_accum_value = 20

        # Weight initialization
        weight_init_value = tf.initializers.random_uniform(0, 1)

        # Add random values before max
        add_random = None
        use_unweighted = True

        # Generate SPN structure.
        # FIX: compare the input_dist selector with `==` instead of `is`;
        # identity comparison against a string literal is implementation
        # dependent and raises a SyntaxWarning on Python 3.8+.
        dense_gen = spn.DenseSPNGenerator(
            num_decomps=num_decomps,
            num_subsets=num_subsets,
            num_mixtures=num_mixtures,
            input_dist=(spn.DenseSPNGenerator.InputDist.RAW
                        if input_dist == "RAW" else
                        spn.DenseSPNGenerator.InputDist.MIXTURE),
            num_input_mixtures=num_input_mixtures,
            balanced=balanced,
            node_type=node_type)
        class_roots = [
            dense_gen.generate(inputs, root_name=("Class_%d" % i))
            for i in range(10)
        ]
        root = spn.Sum(*class_roots, name="root")
        spn.generate_weights(root, init_value=weight_init_value)
        latent = root.generate_ivs()

        # Add EM Learning
        additive_smoothing_var = tf.Variable(additive_smoothing,
                                             dtype=spn.conf.dtype)
        learning = spn.EMLearning(root,
                                  log=log,
                                  value_inference_type=inf_type,
                                  additive_smoothing=additive_smoothing_var,
                                  add_random=add_random,
                                  initial_accum_value=initial_accum_value,
                                  use_unweighted=use_unweighted)

        return root, latent, learning, additive_smoothing, min_additive_smoothing, \
            additive_smoothing_var
# Esempio n. 8
# 0
    def generic_dense_test(self, name, num_decomps, num_subsets, num_mixtures,
                           input_dist, num_input_mixtures):
        """A generic test for DenseSPNGenerator."""
        iv1 = spn.IVs(num_vars=3, num_vals=2, name="IVs1")
        iv2 = spn.IVs(num_vars=3, num_vals=2, name="IVs2")

        generator = spn.DenseSPNGenerator(num_decomps=num_decomps,
                                          num_subsets=num_subsets,
                                          num_mixtures=num_mixtures,
                                          input_dist=input_dist,
                                          num_input_mixtures=num_input_mixtures)

        # Build the SPN and attach random weights
        root = generator.generate(iv1, iv2)
        with tf.name_scope("Weights"):
            spn.generate_weights(root,
                                 tf.initializers.random_uniform(0.0, 1.0))
        init = spn.initialize_weights(root)

        # The generated structure must be a valid SPN
        self.assertTrue(root.is_valid())

        # Value ops in linear and log space
        value_op = root.get_value()
        log_value_op = root.get_log_value()

        with self.test_session() as sess:
            # Initializing weights
            init.run()
            # Enumerate all 2**6 joint assignments of the six binary variables
            feed = np.array(list(itertools.product(range(2), repeat=6)))
            feed_dict = {iv1: feed[:, :3], iv2: feed[:, 3:]}
            out = sess.run(value_op, feed_dict=feed_dict)
            out_log = sess.run(tf.exp(log_value_op), feed_dict=feed_dict)
            # The SPN should be normalized: partition function == 1.0
            self.assertAlmostEqual(out.sum(), 1.0, places=6)
            self.assertAlmostEqual(out_log.sum(), 1.0, places=6)
            self.write_tf_graph(sess, self.sid(), self.cid())
# Esempio n. 9
# 0
    def poons_multi(inputs, num_vals, num_mixtures, num_subsets, inf_type,
                    log=False, output=None):
        """Build a Poon-architecture SPN from multi-op nodes; return value ops."""

        # One ParSums node (num_mixtures parallel sums) per variable subset
        subsets = []
        for i in range(num_subsets):
            var_slice = list(range(i * num_vals, (i + 1) * num_vals))
            subsets.append(spn.ParSums((inputs, var_slice),
                                       num_sums=num_mixtures))
        products = spn.PermProducts(*subsets)
        root = spn.Sum(products, name="root")

        # Generate dense SPN and all weights in the network
        spn.generate_weights(root)

        # Value op in log or linear space, depending on `log`
        value_op = (root.get_log_value(inference_type=inf_type) if log
                    else root.get_value(inference_type=inf_type))

        return root, spn.initialize_weights(root), value_op
# Esempio n. 10
# 0
    def dense_block(inputs,
                    num_decomps,
                    num_subsets,
                    num_mixtures,
                    num_input_mixtures,
                    balanced,
                    input_dist,
                    inf_type,
                    log=False):
        """Build a dense SPN with BLOCK-type nodes and return MPE path ops.

        Returns:
            Tuple of (root, weight-initialization op, list of counts ops,
            one per input node).
        """

        # Use the block-node representation
        node_type = spn.DenseSPNGenerator.NodeType.BLOCK

        # Create a dense generator.
        # FIX: compare the input_dist selector with `==` instead of `is`;
        # identity comparison against a string literal is implementation
        # dependent and raises a SyntaxWarning on Python 3.8+.
        gen = spn.DenseSPNGenerator(
            num_decomps=num_decomps,
            num_subsets=num_subsets,
            num_mixtures=num_mixtures,
            num_input_mixtures=num_input_mixtures,
            balanced=balanced,
            node_type=node_type,
            input_dist=(spn.DenseSPNGenerator.InputDist.RAW
                        if input_dist == "RAW" else
                        spn.DenseSPNGenerator.InputDist.MIXTURE))

        # Generate a dense SPN, with block nodes, and all weights in the network
        root = gen.generate(inputs, root_name="root")
        spn.generate_weights(root, tf.initializers.random_uniform(0.0, 1.0))

        # Generate path ops based on inf_type and log
        if log:
            mpe_path_gen = spn.MPEPath(value_inference_type=inf_type, log=True)
        else:
            mpe_path_gen = spn.MPEPath(value_inference_type=inf_type,
                                       log=False)

        mpe_path_gen.get_mpe_path(root)
        path_ops = [
            mpe_path_gen.counts[inp]
            for inp in (inputs if isinstance(inputs, list) else [inputs])
        ]

        return root, spn.initialize_weights(root), path_ops
# Esempio n. 11
# 0
    def poon_single(inputs, num_vals, num_mixtures, num_subsets, inf_type,
                    log=False, output=None):
        """Build a Poon-architecture SPN from single-op nodes; return value ops."""

        # One list of Sum mixtures per variable subset
        subsets = []
        for i in range(num_subsets):
            var_slice = list(range(i * num_vals, (i + 1) * num_vals))
            subsets.append([spn.Sum((inputs, var_slice))
                            for _ in range(num_mixtures)])
        # One Product per combination of mixtures across subsets
        products = [spn.Product(*combo) for combo in product(*subsets)]
        root = spn.Sum(*products, name="root")

        # Generate dense SPN and all weights in the network
        spn.generate_weights(root)

        # Value op in log or linear space, depending on `log`
        if log:
            value_op = root.get_log_value(inference_type=inf_type)
        else:
            value_op = root.get_value(inference_type=inf_type)

        return root, spn.initialize_weights(root), value_op
# Esempio n. 12
# 0
    def test_traversing_on_dense(self):
        """Compare traversal algs on dense SPN"""
        def count_only(node, *args):
            counter[0] += 1

        def count_and_propagate(node, *args):
            counter[0] += 1
            if node.is_op:
                return [None] * len(node.inputs)

        # Generate dense graph
        v1 = spn.IndicatorLeaf(num_vars=3, num_vals=2, name="IndicatorLeaf1")
        v2 = spn.IndicatorLeaf(num_vars=3, num_vals=2, name="IndicatorLeaf2")

        gen = spn.DenseSPNGenerator(
            num_decomps=2,
            num_subsets=3,
            num_mixtures=2,
            input_dist=spn.DenseSPNGenerator.InputDist.MIXTURE,
            num_input_mixtures=None)
        root = gen.generate(v1, v2)
        spn.generate_weights(root)

        # Run each traversal algorithm and record how many nodes it visits
        counter = [0]
        spn.compute_graph_up_down(root, down_fun=count_and_propagate,
                                  graph_input=1)
        up_down_count = counter[0]

        counter = [0]
        spn.compute_graph_up(root, val_fun=count_only)
        up_count = counter[0]

        counter = [0]
        spn.traverse_graph(root, fun=count_only, skip_params=False)
        traverse_count = counter[0]

        # All three algorithms must visit the same number of nodes
        self.assertEqual(up_down_count, traverse_count)
        self.assertEqual(up_count, traverse_count)
# Esempio n. 13
# 0
    def dense_block(inputs,
                    num_decomps,
                    num_subsets,
                    num_mixtures,
                    num_input_mixtures,
                    balanced,
                    input_dist,
                    inf_type,
                    log=False):
        """Build a dense SPN with BLOCK-type nodes and return value ops.

        Returns:
            Tuple of (root, weight-initialization op, value op).
        """

        # Use the block-node representation
        node_type = spn.DenseSPNGenerator.NodeType.BLOCK

        # Create a dense generator.
        # FIX: compare the input_dist selector with `==` instead of `is`;
        # identity comparison against a string literal is implementation
        # dependent and raises a SyntaxWarning on Python 3.8+.
        gen = spn.DenseSPNGenerator(
            num_decomps=num_decomps,
            num_subsets=num_subsets,
            num_mixtures=num_mixtures,
            num_input_mixtures=num_input_mixtures,
            balanced=balanced,
            node_type=node_type,
            input_dist=(spn.DenseSPNGenerator.InputDist.RAW
                        if input_dist == "RAW" else
                        spn.DenseSPNGenerator.InputDist.MIXTURE))

        # Generate a dense SPN, with block-nodes, and all weights in the network
        root = gen.generate(inputs, root_name="root")
        spn.generate_weights(root, tf.initializers.random_uniform(0.0, 1.0))

        # Generate value ops based on inf_type and log
        if log:
            value_op = root.get_log_value(inference_type=inf_type)
        else:
            value_op = root.get_value(inference_type=inf_type)

        return root, spn.initialize_weights(root), value_op
# Esempio n. 14
# 0
    def test_get_nodes(self):
        """Obtaining the list of nodes in the SPN graph"""
        # Generate graph
        v1 = spn.ContVars(num_vars=1)
        v2 = spn.ContVars(num_vars=1)
        v3 = spn.ContVars(num_vars=1)
        s1 = spn.Sum(v1, v1, v2)  # v1 included twice
        s2 = spn.Sum(v1, v3)
        s3 = spn.Sum(v2, v3, v3)  # v3 included twice
        s4 = spn.Sum(s1, v1)
        s5 = spn.Sum(s2, v3, s3)
        s6 = spn.Sum(s4, s2, s5, s4, s5)  # s4 and s5 included twice
        spn.generate_weights(s6)

        # (start node, expected nodes without params, expected with params)
        cases = [
            (v1, [v1], [v1]),
            (v2, [v2], [v2]),
            (v3, [v3], [v3]),
            (s1, [s1, v1, v2],
             [s1, s1.weights.node, v1, v2]),
            (s2, [s2, v1, v3],
             [s2, s2.weights.node, v1, v3]),
            (s3, [s3, v2, v3],
             [s3, s3.weights.node, v2, v3]),
            (s4, [s4, s1, v1, v2],
             [s4, s4.weights.node, s1, v1, s1.weights.node, v2]),
            (s5, [s5, s2, v3, s3, v1, v2],
             [s5, s5.weights.node, s2, v3, s3,
              s2.weights.node, v1, s3.weights.node, v2]),
            (s6, [s6, s4, s2, s5, s1, v1, v3, s3, v2],
             [s6, s6.weights.node, s4, s2, s5,
              s4.weights.node, s1, v1, s2.weights.node,
              v3, s5.weights.node, s3, s1.weights.node,
              v2, s3.weights.node]),
        ]
        for start, without_params, with_params in cases:
            self.assertListEqual(start.get_nodes(skip_params=True),
                                 without_params)
            self.assertListEqual(start.get_nodes(skip_params=False),
                                 with_params)
# Esempio n. 15
# 0
    def test_compute_graph_down(self):
        """Top-down traversal: down_fun receives accumulated parent values
        and returns one value per input of an op node (or a scalar for a
        leaf). The expected numbers below encode the exact accumulation
        order, so the graph structure must match precisely."""
        counter = [0]
        parent_vals_saved = {}

        def fun(node, parent_vals):
            # Record what this node received from its parents, then emit
            # val + i for the i-th input (leaves just return 101).
            parent_vals_saved[node] = parent_vals
            val = sum(parent_vals) + 0.01
            counter[0] += 1
            if node.is_op:
                return [val + i for i, _ in enumerate(node.inputs)]
            else:
                return 101

        # Generate graph
        v1 = spn.RawLeaf(num_vars=1, name="v1")
        v2 = spn.RawLeaf(num_vars=1, name="v2")
        v3 = spn.RawLeaf(num_vars=1, name="v3")
        s1 = spn.Sum(v1, v1, v2, name="s1")  # v1 included twice
        s2 = spn.Sum(v1, v3, name="s2")
        s3 = spn.Sum(v2, v3, v3, name="s3")  # v3 included twice
        s4 = spn.Sum(s1, v1, name="s4")
        s5 = spn.Sum(s2, v3, s3, name="s5")
        s6 = spn.Sum(s4, s2, s5, s4, s5, name="s6")  # s4 and s5 included twice
        spn.generate_weights(s6)

        down_values = {}
        # graph_input=5 seeds the root's parent-value list with a constant
        spn.compute_graph_up_down(s6, down_fun=fun, graph_input=5,
                                  down_values=down_values)

        # 15 nodes in total (9 graph nodes + 6 weight nodes), each visited once
        self.assertEqual(counter[0], 15)
        # Using sorted since order is not guaranteed
        self.assertListAlmostEqual(sorted(parent_vals_saved[s6]), [5])
        self.assertListAlmostEqual(down_values[s6], [5.01, 6.01, 7.01, 8.01,
                                                     9.01, 10.01, 11.01])
        self.assertListAlmostEqual(sorted(parent_vals_saved[s5]), [9.01, 11.01])
        self.assertListAlmostEqual(down_values[s5], [20.03, 21.03, 22.03,
                                                     23.03, 24.03])
        self.assertListAlmostEqual(sorted(parent_vals_saved[s4]), [7.01, 10.01])
        self.assertListAlmostEqual(down_values[s4], [17.03, 18.03, 19.03, 20.03])
        self.assertListAlmostEqual(sorted(parent_vals_saved[s3]), [24.03])
        self.assertListAlmostEqual(down_values[s3], [24.04, 25.04, 26.04,
                                                     27.04, 28.04])
        self.assertListAlmostEqual(sorted(parent_vals_saved[s2]), [8.01, 22.03])
        self.assertListAlmostEqual(down_values[s2], [30.05, 31.05, 32.05, 33.05])
        self.assertListAlmostEqual(sorted(parent_vals_saved[s1]), [19.03])
        self.assertListAlmostEqual(down_values[s1], [19.04, 20.04, 21.04,
                                                     22.04, 23.04])

        # Leaves receive one parent value per incoming edge and return 101
        self.assertListAlmostEqual(sorted(parent_vals_saved[v1]),
                                   [20.03, 21.04, 22.04, 32.05])
        self.assertEqual(down_values[v1], 101)
        self.assertListAlmostEqual(sorted(parent_vals_saved[v2]),
                                   [23.04, 26.04])
        self.assertEqual(down_values[v2], 101)
        self.assertListAlmostEqual(sorted(parent_vals_saved[v3]),
                                   [23.03, 27.04, 28.04, 33.05])
        self.assertEqual(down_values[v3], 101)

        # Test if the algorithm works on a VarNode and calls graph_input function
        down_values = {}
        parent_vals_saved = {}
        spn.compute_graph_up_down(v1, down_fun=fun, graph_input=lambda: 5,
                                  down_values=down_values)
        self.assertEqual(parent_vals_saved[v1][0], 5)
        self.assertEqual(down_values[v1], 101)
# Esempio n. 16
# 0
    def test_generate_spn(self, num_decomps, num_subsets, num_mixtures,
                          num_input_mixtures, input_dims, input_dist, balanced,
                          node_type, log_weights):
        """A generic test for DenseSPNGenerator.

        Builds three sub-SPNs over the same inputs, then stacks an upper SPN
        on top of their product layers, and checks (a) validity, (b) that the
        partition function is 1.0, and (c) that MPE counts distribute evenly
        over all variable values when all assignments are fed.
        """

        # With RAW input distributions the number of input mixtures has no
        # effect, so only one num_input_mixtures value needs testing.
        if input_dist == spn.DenseSPNGenerator.InputDist.RAW \
            and num_input_mixtures != 1:
            # Redundant test case, so just return
            return

        # Input parameters
        num_inputs = input_dims[0]
        num_vars = input_dims[1]
        num_vals = 2

        # Log the parameter combination being exercised
        printc("\n- num_inputs: %s" % num_inputs)
        printc("- num_vars: %s" % num_vars)
        printc("- num_vals: %s" % num_vals)
        printc("- num_decomps: %s" % num_decomps)
        printc("- num_subsets: %s" % num_subsets)
        printc("- num_mixtures: %s" % num_mixtures)
        printc("- input_dist: %s" %
               ("MIXTURE" if input_dist
                == spn.DenseSPNGenerator.InputDist.MIXTURE else "RAW"))
        printc("- balanced: %s" % balanced)
        printc("- num_input_mixtures: %s" % num_input_mixtures)
        printc("- node_type: %s" %
               ("SINGLE" if node_type == spn.DenseSPNGenerator.NodeType.SINGLE
                else "BLOCK" if node_type
                == spn.DenseSPNGenerator.NodeType.BLOCK else "LAYER"))
        printc("- log_weights: %s" % log_weights)

        # Inputs: one IVs node per input, each with num_vars binary variables
        inputs = [
            spn.IVs(num_vars=num_vars,
                    num_vals=num_vals,
                    name=("IVs_%d" % (i + 1))) for i in range(num_inputs)
        ]

        gen = spn.DenseSPNGenerator(num_decomps=num_decomps,
                                    num_subsets=num_subsets,
                                    num_mixtures=num_mixtures,
                                    input_dist=input_dist,
                                    balanced=balanced,
                                    num_input_mixtures=num_input_mixtures,
                                    node_type=node_type)

        # Generate Sub-SPNs (three independent structures over the same inputs)
        sub_spns = [
            gen.generate(*inputs, root_name=("sub_root_%d" % (i + 1)))
            for i in range(3)
        ]

        # Generate random weights for the first sub-SPN
        with tf.name_scope("Weights"):
            spn.generate_weights(sub_spns[0],
                                 tf.initializers.random_uniform(0.0, 1.0),
                                 log=log_weights)

        # Initialize weights of the first sub-SPN
        sub_spn_init = spn.initialize_weights(sub_spns[0])

        # Testing validity of the first sub-SPN
        self.assertTrue(sub_spns[0].is_valid())

        # Generate value ops of the first sub-SPN
        sub_spn_v = sub_spns[0].get_value()
        sub_spn_v_log = sub_spns[0].get_log_value()

        # Generate path ops of the first sub-SPN
        sub_spn_mpe_path_gen = spn.MPEPath(log=False)
        sub_spn_mpe_path_gen_log = spn.MPEPath(log=True)
        sub_spn_mpe_path_gen.get_mpe_path(sub_spns[0])
        sub_spn_mpe_path_gen_log.get_mpe_path(sub_spns[0])
        sub_spn_path = [sub_spn_mpe_path_gen.counts[inp] for inp in inputs]
        sub_spn_path_log = [
            sub_spn_mpe_path_gen_log.counts[inp] for inp in inputs
        ]

        # Collect all weight nodes of the first sub-SPN
        sub_spn_weight_nodes = []

        def fun(node):
            if node.is_param:
                sub_spn_weight_nodes.append(node)

        spn.traverse_graph(sub_spns[0], fun=fun)

        # Generate an upper-SPN over sub-SPNs: take the product-layer
        # children of each sub-root as the lower layer.
        products_lower = []
        for sub_spn in sub_spns:
            products_lower.append([v.node for v in sub_spn.values])

        # One sum layer per sub-SPN, with a varying number of top mixtures,
        # built with whichever node representation is under test.
        num_top_mixtures = [2, 1, 3]
        sums_lower = []
        for prods, num_top_mix in zip(products_lower, num_top_mixtures):
            if node_type == spn.DenseSPNGenerator.NodeType.SINGLE:
                sums_lower.append(
                    [spn.Sum(*prods) for _ in range(num_top_mix)])
            elif node_type == spn.DenseSPNGenerator.NodeType.BLOCK:
                sums_lower.append([spn.ParSums(*prods, num_sums=num_top_mix)])
            else:
                sums_lower.append([
                    spn.SumsLayer(*prods * num_top_mix,
                                  num_or_size_sums=num_top_mix)
                ])

        # Generate upper-SPN
        root = gen.generate(*list(itertools.chain(*sums_lower)),
                            root_name="root")

        # Generate random weights for the SPN
        with tf.name_scope("Weights"):
            spn.generate_weights(root,
                                 tf.initializers.random_uniform(0.0, 1.0),
                                 log=log_weights)

        # Initialize weight of the SPN
        spn_init = spn.initialize_weights(root)

        # Testing validity of the SPN
        self.assertTrue(root.is_valid())

        # Generate value ops of the SPN
        spn_v = root.get_value()
        spn_v_log = root.get_log_value()

        # Generate path ops of the SPN
        spn_mpe_path_gen = spn.MPEPath(log=False)
        spn_mpe_path_gen_log = spn.MPEPath(log=True)
        spn_mpe_path_gen.get_mpe_path(root)
        spn_mpe_path_gen_log.get_mpe_path(root)
        spn_path = [spn_mpe_path_gen.counts[inp] for inp in inputs]
        spn_path_log = [spn_mpe_path_gen_log.counts[inp] for inp in inputs]

        # Collect all weight nodes in the SPN
        spn_weight_nodes = []

        def fun(node):
            if node.is_param:
                spn_weight_nodes.append(node)

        spn.traverse_graph(root, fun=fun)

        # Create a session
        with self.test_session() as sess:
            # Initializing weights
            sess.run(sub_spn_init)
            sess.run(spn_init)

            # Generate input feed: every joint assignment of all variables
            feed = np.array(
                list(
                    itertools.product(range(num_vals),
                                      repeat=(num_inputs * num_vars))))
            batch_size = feed.shape[0]
            feed_dict = {}
            for inp, f in zip(inputs, np.split(feed, num_inputs, axis=1)):
                feed_dict[inp] = f

            # Compute all values and paths of sub-SPN
            sub_spn_out = sess.run(sub_spn_v, feed_dict=feed_dict)
            sub_spn_out_log = sess.run(tf.exp(sub_spn_v_log),
                                       feed_dict=feed_dict)
            sub_spn_out_path = sess.run(sub_spn_path, feed_dict=feed_dict)
            sub_spn_out_path_log = sess.run(sub_spn_path_log,
                                            feed_dict=feed_dict)

            # Compute all values and paths of the complete SPN
            spn_out = sess.run(spn_v, feed_dict=feed_dict)
            spn_out_log = sess.run(tf.exp(spn_v_log), feed_dict=feed_dict)
            spn_out_path = sess.run(spn_path, feed_dict=feed_dict)
            spn_out_path_log = sess.run(spn_path_log, feed_dict=feed_dict)

            # Test if partition function of the sub-SPN and of the
            # complete SPN is 1.0
            self.assertAlmostEqual(sub_spn_out.sum(), 1.0, places=6)
            self.assertAlmostEqual(sub_spn_out_log.sum(), 1.0, places=6)
            self.assertAlmostEqual(spn_out.sum(), 1.0, places=6)
            self.assertAlmostEqual(spn_out_log.sum(), 1.0, places=6)

            # Test if the sum of counts for each value of each variable
            # (6 variables, with 2 values each) = batch-size / num-vals
            self.assertEqual(
                np.sum(np.hstack(sub_spn_out_path), axis=0).tolist(),
                [batch_size // num_vals] * num_inputs * num_vars * num_vals)
            self.assertEqual(
                np.sum(np.hstack(sub_spn_out_path_log), axis=0).tolist(),
                [batch_size // num_vals] * num_inputs * num_vars * num_vals)
            self.assertEqual(
                np.sum(np.hstack(spn_out_path), axis=0).tolist(),
                [batch_size // num_vals] * num_inputs * num_vars * num_vals)
            self.assertEqual(
                np.sum(np.hstack(spn_out_path_log), axis=0).tolist(),
                [batch_size // num_vals] * num_inputs * num_vars * num_vals)