Example #1
    def par_sums(inputs,
                 indices,
                 latent_indicators,
                 num_sums,
                 inf_type=None,
                 log=True,
                 output=None):
        if indices is None:
            inputs = [inputs]
        else:
            inputs = [(inputs, indices)]

        # Generate a single ParallelSums node, modeling 'num_sums' sum nodes
        # within, connecting it to inputs and latent_indicators
        s = spn.ParallelSums(*inputs,
                             num_sums=num_sums,
                             latent_indicators=latent_indicators[-1])
        # Generate weights of the ParallelSums node
        weights = s.generate_weights()

        # Connect the ParallelSums node to a single root Sum node and generate
        # its weights
        root = spn.Sum(s)
        root.generate_weights()

        mpe_path_gen = spn.MPEPath(value_inference_type=inf_type, log=log)

        mpe_path_gen.get_mpe_path(root)
        path_op = [mpe_path_gen.counts[weights]]
        return spn.initialize_weights(root), path_op
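
A minimal sketch of how `par_sums` might be driven, assuming a libspn build that exposes the `spn.ContVars` continuous leaf (renamed `RawLeaf` in later versions); the leaf name and sizes here are illustrative only, and `latent_indicators=[None]` simply disables the indicator input:

    import numpy as np
    import tensorflow as tf
    import libspn as spn

    # Hypothetical leaf with 4 variables feeding 3 parallel sums
    inputs = spn.ContVars(num_vars=4)
    init, path_op = par_sums(inputs, indices=None, latent_indicators=[None],
                             num_sums=3, inf_type=spn.InferenceType.MPE)
    with tf.Session() as sess:
        sess.run(init)
        counts = sess.run(path_op, feed_dict={inputs: np.random.rand(32, 4)})
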
Example #2
    def test_mpe_path(self):
        # Generate SPN
        model = spn.Poon11NaiveMixtureModel()
        model.build()
        # Add ops
        init = spn.initialize_weights(model.root)
        mpe_path_gen = spn.MPEPath(value_inference_type=spn.InferenceType.MPE,
                                   log=False)
        mpe_path_gen_log = spn.MPEPath(
            value_inference_type=spn.InferenceType.MPE, log=True)
        mpe_path_gen.get_mpe_path(model.root)
        mpe_path_gen_log.get_mpe_path(model.root)
        # Run
        with self.test_session() as sess:
            init.run()
            out = sess.run(mpe_path_gen.counts[model.latent_indicators],
                           feed_dict={model.latent_indicators: model.feed})
            out_log = sess.run(
                mpe_path_gen_log.counts[model.latent_indicators],
                feed_dict={model.latent_indicators: model.feed})

        true_latent_indicators_counts = np.array(
            [[0., 1., 1., 0.], [0., 1., 1., 0.], [0., 1., 0., 1.],
             [1., 0., 1., 0.], [1., 0., 1., 0.], [1., 0., 0., 1.],
             [0., 1., 1., 0.], [0., 1., 1., 0.], [0., 1., 0., 1.]],
            dtype=spn.conf.dtype.as_numpy_dtype)

        np.testing.assert_array_equal(out, true_latent_indicators_counts)
        np.testing.assert_array_equal(out_log, true_latent_indicators_counts)
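
The expected indicator counts above follow from the MPE downward pass: each sum node routes its incoming count to the child branch (and the matching latent indicator) whose weighted value is largest. A tiny numpy sketch of that routing rule for a single sum node, independent of libspn:

    import numpy as np

    weights = np.array([0.6, 0.4])          # sum-node weights
    child_values = np.array([[0.2, 0.9],    # upward values for a batch of 2
                             [0.8, 0.1]])
    winner = np.argmax(weights * child_values, axis=1)
    counts = np.eye(len(weights))[winner]   # one-hot count to the winning child
    # counts -> [[0., 1.], [1., 0.]]
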
Example #3
    def poons_multi(inputs,
                    num_vals,
                    num_mixtures,
                    num_subsets,
                    inf_type,
                    log=False,
                    output=None):

        # Build a Poon-like network with multi-op nodes
        subsets = [
            spn.ParSums((inputs, list(range(i * num_vals,
                                            (i + 1) * num_vals))),
                        num_sums=num_mixtures) for i in range(num_subsets)
        ]
        products = spn.PermProducts(*subsets)
        root = spn.Sum(products, name="root")

        # Generate dense SPN and all weights in the network
        spn.generate_weights(root)

        # Generate path ops based on inf_type and log
        mpe_path_gen = spn.MPEPath(value_inference_type=inf_type, log=log)

        mpe_path_gen.get_mpe_path(root)
        path_ops = [
            mpe_path_gen.counts[inp]
            for inp in (inputs if isinstance(inputs, list) else [inputs])
        ]
        return root, spn.initialize_weights(root), path_ops
Example #4
    def sum(inputs,
            indices,
            latent_indicators,
            num_sums,
            inf_type=None,
            log=False,
            output=None):
        if indices is None:
            inputs = [inputs]
        else:
            inputs = [(inputs, indices)]

        # Generate 'num_sums' Sum nodes, connecting each to inputs and latent_indicators
        s = []
        weights = []
        for i in range(num_sums):
            s.append(spn.Sum(*inputs, latent_indicators=latent_indicators[i]))
            weights.append(s[-1].generate_weights())

        # Connect all sum nodes to a single root Sum node and generate its weights
        root = spn.Sum(*s)
        root.generate_weights()

        mpe_path_gen = spn.MPEPath(value_inference_type=inf_type, log=log)

        mpe_path_gen.get_mpe_path(root)
        path_ops = [mpe_path_gen.counts[w] for w in weights]
        return spn.initialize_weights(root), path_ops
Example #5
    def sums(inputs,
             indices,
             ivs,
             num_sums,
             inf_type=None,
             log=True,
             output=None):
        if indices is None:
            inputs = [inputs for _ in range(num_sums)]
        else:
            inputs = [(inputs, indices) for _ in range(num_sums)]

        # Generate a single Sums node, modeling 'num_sums' sum nodes within,
        # connecting it to inputs and ivs
        s = spn.Sums(*inputs, num_sums=num_sums, ivs=ivs[-1])
        # Generate weights of the Sums node
        weights = s.generate_weights()

        # Connect the Sums node to a single root Sum node and generate its weights
        root = spn.Sum(s)
        root.generate_weights()

        mpe_path_gen = spn.MPEPath(value_inference_type=inf_type, log=log)

        mpe_path_gen.get_mpe_path(root)
        path_op = [mpe_path_gen.counts[weights]]
        return spn.initialize_weights(root), path_op
Example #6
    def products(inputs,
                 num_inputs,
                 num_input_cols,
                 num_prods,
                 inf_type,
                 indices=None,
                 log=False,
                 output=None):
        p = []
        # Generate 'len(inputs)' Products nodes, modeling 'n_prods' ∈ 'num_prods'
        # products within each
        for inps, n_inp_cols, n_prods in zip(inputs, num_input_cols,
                                             num_prods):
            num_inputs = len(inps)
            # Create permuted indices based on number and size of inps
            inds = map(int, np.arange(n_inp_cols))
            permuted_inds = list(product(inds, repeat=num_inputs))
            permuted_inds_list = [list(elem) for elem in permuted_inds]
            permuted_inds_list_of_list = [[[ind] for ind in elem]
                                          for elem in permuted_inds_list]

            # Create inputs-list by combining inps and indices
            permuted_inputs = []
            for indices in permuted_inds_list_of_list:
                permuted_inputs.append([tuple(i) for i in zip(inps, indices)])
            permuted_inputs = list(chain.from_iterable(permuted_inputs))

            # Generate a single Products node, modeling 'n_prods' product nodes
            # within, connecting it to inputs
            p.append(spn.Products(*permuted_inputs, num_prods=n_prods))

        # Connect all product nodes to a single root Sum node and generate its
        # weights
        root = spn.Sum(*p)
        root.generate_weights()

        mpe_path_gen = spn.MPEPath(value_inference_type=inf_type, log=log)

        mpe_path_gen.get_mpe_path(root)
        path_ops = [
            mpe_path_gen.counts[inp]
            for inp in list(chain.from_iterable(inputs))
        ]
        return spn.initialize_weights(root), path_ops
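
The permuted-indices construction above is easier to follow with concrete numbers. A standalone sketch, assuming two hypothetical inputs with two columns each, that reproduces the flattened `(input, [index])` pairs handed to `spn.Products`:

    from itertools import product

    inps = ["in_a", "in_b"]   # hypothetical input nodes
    n_inp_cols = 2            # columns per input
    permuted_inds = list(product(range(n_inp_cols), repeat=len(inps)))
    # -> [(0, 0), (0, 1), (1, 0), (1, 1)]
    permuted_inputs = [(inp, [ind]) for perm in permuted_inds
                       for inp, ind in zip(inps, perm)]
    # -> [('in_a', [0]), ('in_b', [0]), ('in_a', [0]), ('in_b', [1]), ...]
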
Example #7
    def perm_products(inputs, num_inputs, num_input_cols, num_prods, inf_type,
                      indices=None, log=False, output=None):
        if indices is not None:
            # Create inputs list with indices
            inputs_list = [[(inp, ind) for inp, ind in zip(inps, inds)]
                           for inps, inds in zip(inputs, indices)]
        else:
            inputs_list = inputs

        if isinstance(inputs, list):  # Is a list of ContVars inputs - Multiple inputs
            # Generate 'len(inputs)' PermProducts nodes, modeling 'n_prods' products
            # within each
            p = [spn.PermProducts(*inps) for inps in inputs]
        else:  # Is a single input of type ContVars - A single input
            num_inputs_array = np.array(num_inputs)
            num_input_cols_array = np.array(num_input_cols)
            num_cols = num_input_cols[0]
            num_vars = int(np.sum(num_inputs_array * num_input_cols_array))

            indices_list = [list(range(i, i+num_cols)) for i in range(0, num_vars,
                                                                      num_cols)]
            num_inputs_cumsum = np.cumsum(num_inputs_array).tolist()
            num_inputs_cumsum.insert(0, 0)

            inputs_list = [[(inputs, inds) for inds in indices_list[start:stop]]
                           for start, stop in zip(num_inputs_cumsum[:-1],
                                                  num_inputs_cumsum[1:])]

            # Generate 'len(inputs)' PermProducts nodes, modeling 'n_prods'
            # products within each, with the inputs of every node emanating
            # from a common input source
            p = [spn.PermProducts(*inps) for inps in inputs_list]

        # Connect all PermProducts nodes to a single root Sum node and generate
        # its weights
        root = spn.Sum(*p)
        root.generate_weights()

        mpe_path_gen = spn.MPEPath(value_inference_type=inf_type, log=log)

        mpe_path_gen.get_mpe_path(root)
        if isinstance(inputs, list):  # Is a list of ContVars inputs - Multiple inputs
            path_ops = [mpe_path_gen.counts[inp] for inp in
                        list(chain.from_iterable(inputs))]
        else:  # Is a single input of type ContVars - A single input
            path_ops = mpe_path_gen.counts[inputs]
        return spn.initialize_weights(root), path_ops
Example #8
    def _build_op(self, inputs, placeholders, conf):
        """ Creates the graph using only ParSum nodes """
        # TODO make sure the ivs are correct
        sum_indices, weights, ivs = inputs.indices, inputs.weights, None
        log, inf_type = conf.log, conf.inf_type
        weights = np.split(
            weights,
            np.cumsum([len(ind) * inputs.num_parallel
                       for ind in sum_indices])[:-1])

        parallel_sum_nodes = []
        for ind in sum_indices:
            parallel_sum_nodes.append(
                spn.ParSums((placeholders[0], ind),
                            num_sums=inputs.num_parallel))

        weight_nodes = [
            self._generate_weights(node, w.tolist())
            for node, w in zip(parallel_sum_nodes, weights)
        ]
        if ivs:
            for s, iv in zip(parallel_sum_nodes, ivs):
                s.set_ivs(iv)
        root = spn.Sum(*parallel_sum_nodes)
        self._generate_weights(root)

        mpe_path_gen = spn.MPEPath(value_inference_type=inf_type, log=log)
        mpe_path_gen.get_mpe_path(root)
        path_op = [mpe_path_gen.counts[w] for w in weight_nodes]
        input_counts = [mpe_path_gen.counts[inp] for inp in placeholders]

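        # tf.tuple makes the input-count ops control dependencies of path_op, so
        # they are evaluated even though only the path_op slice is returned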
        return tf.tuple(
            path_op + input_counts)[:len(path_op)], self._initialize_from(root)
Example #9
    def dense_block(inputs,
                    num_decomps,
                    num_subsets,
                    num_mixtures,
                    num_input_mixtures,
                    balanced,
                    input_dist,
                    inf_type,
                    log=False):

        # Set node type to block nodes
        node_type = spn.DenseSPNGenerator.NodeType.BLOCK

        # Create a dense generator
        gen = spn.DenseSPNGenerator(
            num_decomps=num_decomps,
            num_subsets=num_subsets,
            num_mixtures=num_mixtures,
            num_input_mixtures=num_input_mixtures,
            balanced=balanced,
            node_type=node_type,
            input_dist=(spn.DenseSPNGenerator.InputDist.RAW
                        if input_dist == "RAW" else
                        spn.DenseSPNGenerator.InputDist.MIXTURE))

        # Generate a dense SPN with block nodes, and all weights in the network
        root = gen.generate(inputs, root_name="root")
        spn.generate_weights(root, tf.initializers.random_uniform(0.0, 1.0))

        # Generate path ops based on inf_type and log
        mpe_path_gen = spn.MPEPath(value_inference_type=inf_type, log=log)

        mpe_path_gen.get_mpe_path(root)
        path_ops = [
            mpe_path_gen.counts[inp]
            for inp in (inputs if isinstance(inputs, list) else [inputs])
        ]

        return root, spn.initialize_weights(root), path_ops
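
A sketch of how `dense_block` might be called, assuming an `spn.IVs` leaf as in Example #13; the variable counts and generator settings are illustrative only:

    import numpy as np
    import tensorflow as tf
    import libspn as spn

    inputs = spn.IVs(num_vars=6, num_vals=2)
    root, init, path_ops = dense_block(inputs, num_decomps=1, num_subsets=3,
                                       num_mixtures=2, num_input_mixtures=2,
                                       balanced=True, input_dist="RAW",
                                       inf_type=spn.InferenceType.MPE)
    with tf.Session() as sess:
        sess.run(init)
        feed = np.random.randint(2, size=(32, 6))  # IVs values in [0, num_vals)
        counts = sess.run(path_ops, feed_dict={inputs: feed})
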
Example #10
    def _build_op(self, inputs, placeholders, conf):
        # TODO make sure the ivs are correct
        sum_indices, weights, ivs = inputs.indices, inputs.weights, None
        log, inf_type = conf.log, conf.inf_type
        repeated_inputs = []
        repeated_sum_sizes = []
        offset = 0
        for ind in sum_indices:
            # Indices are given by looking at the sizes of the sums
            size = len(ind)
            repeated_inputs.extend([(placeholders[0], ind)
                                    for _ in range(inputs.num_parallel)])
            repeated_sum_sizes.extend(
                [size for _ in range(inputs.num_parallel)])
            offset += size

        # Globally configure how counts are summed before being passed down to children
        spn.conf.sumslayer_count_sum_strategy = self.sum_count_strategy
        sums_layer = spn.SumsLayer(*repeated_inputs,
                                   num_or_size_sums=repeated_sum_sizes)
        weight_node = self._generate_weights(sums_layer, weights)
        if ivs:
            sums_layer.set_ivs(*ivs)
        # Connect a single sum to group outcomes
        root = spn.Sum(sums_layer)

        self._generate_weights(root)
        # Then build MPE path Ops
        mpe_path_gen = spn.MPEPath(value_inference_type=inf_type, log=log)
        mpe_path_gen.get_mpe_path(root)
        path_op = [
            tf.tuple([
                mpe_path_gen.counts[weight_node],
                mpe_path_gen.counts[placeholders[0]]
            ])[0]
        ]

        return path_op, self._initialize_from(root)
Example #11
    def test_sumslayer_mpe_path(self, input_sizes, sum_sizes,
                                latent_indicators, log, same_inputs, inf_type,
                                count_strategy, indices, use_unweighted):
        spn.conf.argmax_zero = True

        # Set some defaults
        if (1 in sum_sizes or 1 in input_sizes or np.all(np.equal(sum_sizes, sum_sizes[0]))) \
                and use_unweighted:
            # There is no clean way to solve the issue avoided here. It comes down
            # to floating-point differences between numpy and tf, which make the
            # behavior of argmax unpredictable. Unweighted values take away any
            # weighting randomness, so argmax will encounter values that are very
            # likely equal up to these floating-point errors. Hence, we simply skip
            # the test when the sum size or input size equals 1 (which is typically
            # when the values are 'pseudo'-equal)
            return None

        batch_size = 32
        factor = 10
        # Configure count strategy
        spn.conf.sumslayer_count_sum_strategy = count_strategy
        feed_dict, indices, input_nodes, input_tuples, latent_indicators, values, weights, root_weights = \
            self.sumslayer_prepare_common(
                batch_size, factor, indices, input_sizes, latent_indicators, same_inputs, sum_sizes)

        root_weights_np = np.ones_like(
            root_weights) if use_unweighted and log else root_weights
        weight_counts, latent_indicators_counts, value_counts = sumslayer_mpe_path_numpy(
            values, indices, weights,
            None if not latent_indicators else latent_indicators, sum_sizes,
            inf_type, root_weights_np)
        # Build graph
        init, latent_indicators_nodes, root, weight_node = self.build_sumslayer_common(
            feed_dict, input_tuples, latent_indicators, sum_sizes, weights,
            root_weights)

        # Then build MPE path Ops
        mpe_path_gen = spn.MPEPath(value_inference_type=inf_type,
                                   log=True,
                                   use_unweighted=use_unweighted)
        mpe_path_gen.get_mpe_path(root)
        path_op = [
            mpe_path_gen.counts[node]
            for node in [weight_node] + input_nodes + latent_indicators_nodes
        ]

        # Run graph and do some post-processing
        with self.test_session() as sess:
            sess.run(init)
            out = sess.run(path_op, feed_dict=feed_dict)
            if latent_indicators:
                latent_indicators_counts_out = out[-1]
                latent_indicators_counts_out = np.split(
                    latent_indicators_counts_out,
                    indices_or_sections=len(sum_sizes),
                    axis=1)
                latent_indicators_counts_out = [
                    np.squeeze(iv, axis=1)[:, :size] for iv, size in zip(
                        latent_indicators_counts_out, sum_sizes)
                ]
                out = out[:-1]
            weight_counts_out, *input_counts_out = out
            weight_counts_out = np.split(weight_counts_out,
                                         indices_or_sections=len(sum_sizes),
                                         axis=1)
            weight_counts_out = [
                np.squeeze(w, axis=1)[:, :size]
                for w, size in zip(weight_counts_out, sum_sizes)
            ]
        if same_inputs:
            value_counts = [np.sum(value_counts, axis=0)]

        # Test outputs
        for inp_count_out, inp_count in zip(input_counts_out, value_counts):
            self.assertAllClose(inp_count_out, inp_count)
        for w_out, w_out_truth in zip(weight_counts_out, weight_counts):
            self.assertAllClose(w_out, w_out_truth)
        if latent_indicators:
            for iv_out, iv_true_out in zip(latent_indicators_counts_out,
                                           latent_indicators_counts):
                self.assertAllClose(iv_out, iv_true_out)
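
The floating-point tie problem skipped at the top of this test is easy to reproduce in plain numpy: argmax breaks exact ties by taking the first index, so values that are only equal up to rounding can flip the winner between backends:

    import numpy as np

    a = np.array([0.1 + 0.2, 0.3])  # mathematically equal, unequal in float64
    print(a[0] == a[1])             # False: 0.30000000000000004 vs 0.3
    print(np.argmax(a))             # 0 here; a backend rounding differently picks 1
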
Example #12
    def products_layer(inputs,
                       num_inputs,
                       num_input_cols,
                       num_prods,
                       inf_type,
                       indices=None,
                       log=False,
                       output=None):
        products_inputs = []
        num_or_size_prods = []
        if isinstance(inputs, list):  # Is a list of RawLeaf inputs - Multiple inputs
            for inps, n_inp_cols, n_prods in zip(inputs, num_input_cols,
                                                 num_prods):
                num_inputs = len(inps)
                # Create permuted indices based on number and size of inputs
                inds = map(int, np.arange(n_inp_cols))
                permuted_inds = list(product(inds, repeat=num_inputs))
                permuted_inds_list = [list(elem) for elem in permuted_inds]
                permuted_inds_list_of_list = [[[ind] for ind in elem]
                                              for elem in permuted_inds_list]

                # Create inputs list by combining inputs and indices
                permuted_inputs = []
                for indices in permuted_inds_list_of_list:
                    permuted_inputs.append(
                        [tuple(i) for i in zip(inps, indices)])
                products_inputs += list(chain.from_iterable(permuted_inputs))

                # Create products-size list
                num_or_size_prods += [num_inputs] * n_prods
        else:  # Is a single input of type RawLeaf - A single input
            outer_offset = 0
            permuted_inds_list = []
            for n_inps, n_inp_cols in zip(num_inputs, num_input_cols):
                # Create permuted indices based on number and size of inputs
                inds = map(int, np.arange(n_inp_cols))
                permuted_inds = list(product(inds, repeat=n_inps))
                offsets = \
                    np.array(list(range(0, (n_inps * n_inp_cols), n_inp_cols))) \
                    + outer_offset
                outer_offset += n_inps * n_inp_cols
                for perm_inds in permuted_inds:
                    permuted_inds_list.append([
                        p_ind + off
                        for p_ind, off in zip(list(perm_inds), offsets)
                    ])

            # The elements of 'perm_inds' must be of type int; otherwise the
            # input_parser in the Input class complains
            products_inputs = [(inputs, list(map(int, perm_inds)))
                               for perm_inds in permuted_inds_list]
            num_or_size_prods = [
                len(perm_inds) for perm_inds in permuted_inds_list
            ]

        # Generate a single ProductsLayer node, modeling 'sum(num_prods)' products
        # within, connecting it to inputs
        p = spn.ProductsLayer(*products_inputs,
                              num_or_size_prods=num_or_size_prods)

        # Connect all product nodes to a single root Sum node and generate its
        # weights
        root = spn.Sum(p)
        root.generate_weights()

        mpe_path_gen = spn.MPEPath(value_inference_type=inf_type, log=log)

        mpe_path_gen.get_mpe_path(root)
        if isinstance(inputs, list):
            path_ops = [
                mpe_path_gen.counts[inp]
                for inp in list(chain.from_iterable(inputs))
            ]
        else:
            path_ops = mpe_path_gen.counts[inputs]
        return spn.initialize_weights(root), path_ops
Example #13
    def test_generate_spn(self, num_decomps, num_subsets, num_mixtures,
                          num_input_mixtures, input_dims, input_dist, balanced,
                          node_type, log_weights):
        """A generic test for DenseSPNGenerator."""

        if input_dist == spn.DenseSPNGenerator.InputDist.RAW \
            and num_input_mixtures != 1:
            # Redundant test case, so just return
            return

        # Input parameters
        num_inputs = input_dims[0]
        num_vars = input_dims[1]
        num_vals = 2

        printc("\n- num_inputs: %s" % num_inputs)
        printc("- num_vars: %s" % num_vars)
        printc("- num_vals: %s" % num_vals)
        printc("- num_decomps: %s" % num_decomps)
        printc("- num_subsets: %s" % num_subsets)
        printc("- num_mixtures: %s" % num_mixtures)
        printc("- input_dist: %s" %
               ("MIXTURE" if input_dist
                == spn.DenseSPNGenerator.InputDist.MIXTURE else "RAW"))
        printc("- balanced: %s" % balanced)
        printc("- num_input_mixtures: %s" % num_input_mixtures)
        printc("- node_type: %s" %
               ("SINGLE" if node_type == spn.DenseSPNGenerator.NodeType.SINGLE
                else "BLOCK" if node_type
                == spn.DenseSPNGenerator.NodeType.BLOCK else "LAYER"))
        printc("- log_weights: %s" % log_weights)

        # Inputs
        inputs = [
            spn.IVs(num_vars=num_vars,
                    num_vals=num_vals,
                    name=("IVs_%d" % (i + 1))) for i in range(num_inputs)
        ]

        gen = spn.DenseSPNGenerator(num_decomps=num_decomps,
                                    num_subsets=num_subsets,
                                    num_mixtures=num_mixtures,
                                    input_dist=input_dist,
                                    balanced=balanced,
                                    num_input_mixtures=num_input_mixtures,
                                    node_type=node_type)

        # Generate Sub-SPNs
        sub_spns = [
            gen.generate(*inputs, root_name=("sub_root_%d" % (i + 1)))
            for i in range(3)
        ]

        # Generate random weights for the first sub-SPN
        with tf.name_scope("Weights"):
            spn.generate_weights(sub_spns[0],
                                 tf.initializers.random_uniform(0.0, 1.0),
                                 log=log_weights)

        # Initialize weights of the first sub-SPN
        sub_spn_init = spn.initialize_weights(sub_spns[0])

        # Testing validity of the first sub-SPN
        self.assertTrue(sub_spns[0].is_valid())

        # Generate value ops of the first sub-SPN
        sub_spn_v = sub_spns[0].get_value()
        sub_spn_v_log = sub_spns[0].get_log_value()

        # Generate path ops of the first sub-SPN
        sub_spn_mpe_path_gen = spn.MPEPath(log=False)
        sub_spn_mpe_path_gen_log = spn.MPEPath(log=True)
        sub_spn_mpe_path_gen.get_mpe_path(sub_spns[0])
        sub_spn_mpe_path_gen_log.get_mpe_path(sub_spns[0])
        sub_spn_path = [sub_spn_mpe_path_gen.counts[inp] for inp in inputs]
        sub_spn_path_log = [
            sub_spn_mpe_path_gen_log.counts[inp] for inp in inputs
        ]

        # Collect all weight nodes of the first sub-SPN
        sub_spn_weight_nodes = []

        def fun(node):
            if node.is_param:
                sub_spn_weight_nodes.append(node)

        spn.traverse_graph(sub_spns[0], fun=fun)

        # Generate an upper-SPN over sub-SPNs
        products_lower = []
        for sub_spn in sub_spns:
            products_lower.append([v.node for v in sub_spn.values])

        num_top_mixtures = [2, 1, 3]
        sums_lower = []
        for prods, num_top_mix in zip(products_lower, num_top_mixtures):
            if node_type == spn.DenseSPNGenerator.NodeType.SINGLE:
                sums_lower.append(
                    [spn.Sum(*prods) for _ in range(num_top_mix)])
            elif node_type == spn.DenseSPNGenerator.NodeType.BLOCK:
                sums_lower.append([spn.ParSums(*prods, num_sums=num_top_mix)])
            else:
                sums_lower.append([
                    spn.SumsLayer(*prods * num_top_mix,
                                  num_or_size_sums=num_top_mix)
                ])

        # Generate upper-SPN
        root = gen.generate(*list(itertools.chain(*sums_lower)),
                            root_name="root")

        # Generate random weights for the SPN
        with tf.name_scope("Weights"):
            spn.generate_weights(root,
                                 tf.initializers.random_uniform(0.0, 1.0),
                                 log=log_weights)

        # Initialize weight of the SPN
        spn_init = spn.initialize_weights(root)

        # Testing validity of the SPN
        self.assertTrue(root.is_valid())

        # Generate value ops of the SPN
        spn_v = root.get_value()
        spn_v_log = root.get_log_value()

        # Generate path ops of the SPN
        spn_mpe_path_gen = spn.MPEPath(log=False)
        spn_mpe_path_gen_log = spn.MPEPath(log=True)
        spn_mpe_path_gen.get_mpe_path(root)
        spn_mpe_path_gen_log.get_mpe_path(root)
        spn_path = [spn_mpe_path_gen.counts[inp] for inp in inputs]
        spn_path_log = [spn_mpe_path_gen_log.counts[inp] for inp in inputs]

        # Collect all weight nodes in the SPN
        spn_weight_nodes = []

        def fun(node):
            if node.is_param:
                spn_weight_nodes.append(node)

        spn.traverse_graph(root, fun=fun)

        # Create a session
        with self.test_session() as sess:
            # Initializing weights
            sess.run(sub_spn_init)
            sess.run(spn_init)

            # Generate input feed
            feed = np.array(
                list(
                    itertools.product(range(num_vals),
                                      repeat=(num_inputs * num_vars))))
            batch_size = feed.shape[0]
            feed_dict = {}
            for inp, f in zip(inputs, np.split(feed, num_inputs, axis=1)):
                feed_dict[inp] = f

            # Compute all values and paths of sub-SPN
            sub_spn_out = sess.run(sub_spn_v, feed_dict=feed_dict)
            sub_spn_out_log = sess.run(tf.exp(sub_spn_v_log),
                                       feed_dict=feed_dict)
            sub_spn_out_path = sess.run(sub_spn_path, feed_dict=feed_dict)
            sub_spn_out_path_log = sess.run(sub_spn_path_log,
                                            feed_dict=feed_dict)

            # Compute all values and paths of the complete SPN
            spn_out = sess.run(spn_v, feed_dict=feed_dict)
            spn_out_log = sess.run(tf.exp(spn_v_log), feed_dict=feed_dict)
            spn_out_path = sess.run(spn_path, feed_dict=feed_dict)
            spn_out_path_log = sess.run(spn_path_log, feed_dict=feed_dict)

            # Test if partition function of the sub-SPN and of the
            # complete SPN is 1.0
            self.assertAlmostEqual(sub_spn_out.sum(), 1.0, places=6)
            self.assertAlmostEqual(sub_spn_out_log.sum(), 1.0, places=6)
            self.assertAlmostEqual(spn_out.sum(), 1.0, places=6)
            self.assertAlmostEqual(spn_out_log.sum(), 1.0, places=6)

            # Test that the sum of counts for each value of each variable
            # (num_inputs * num_vars variables, num_vals values each) equals
            # batch_size / num_vals
            self.assertEqual(
                np.sum(np.hstack(sub_spn_out_path), axis=0).tolist(),
                [batch_size // num_vals] * num_inputs * num_vars * num_vals)
            self.assertEqual(
                np.sum(np.hstack(sub_spn_out_path_log), axis=0).tolist(),
                [batch_size // num_vals] * num_inputs * num_vars * num_vals)
            self.assertEqual(
                np.sum(np.hstack(spn_out_path), axis=0).tolist(),
                [batch_size // num_vals] * num_inputs * num_vars * num_vals)
            self.assertEqual(
                np.sum(np.hstack(spn_out_path_log), axis=0).tolist(),
                [batch_size // num_vals] * num_inputs * num_vars * num_vals)