Exemplo n.º 1
0
    def perm_products(inputs,
                      num_inputs,
                      num_input_cols,
                      num_prods,
                      inf_type,
                      indices=None,
                      log=False,
                      output=None):
        """Build PermProducts nodes under one Sum root and return a value op.

        Args:
            inputs: Either a list holding one input list per PermProducts
                node, or a single input node shared by every PermProducts
                node.
            num_inputs: Per-node count of column groups; only read when
                ``inputs`` is not a list.
            num_input_cols: Per-node number of columns per group; only read
                when ``inputs`` is not a list. Its first entry sets the
                width of every column group.
            num_prods: Not used by this function.
            inf_type: Inference type forwarded to the root's value method.
            indices: Optional nested list of index selections, paired
                element-wise with ``inputs``.
            log: When truthy the op comes from ``get_log_value``, otherwise
                from ``get_value``.
            output: Not used by this function.

        Returns:
            The tuple ``(spn.initialize_weights(root), value_op)``.
        """
        if indices is not None:
            # Attach the matching index selection to every input
            inputs = [list(zip(inps, inds))
                      for inps, inds in zip(inputs, indices)]

        if isinstance(inputs, list):
            # Multiple inputs: every entry already lists one node's inputs
            per_node_inputs = inputs
        else:
            # Single input: every node reads its own column groups from the
            # one shared source
            counts = np.array(num_inputs)
            cols_per_input = np.array(num_input_cols)
            group_width = num_input_cols[0]
            total_cols = int(np.sum(counts * cols_per_input))

            # Consecutive, equally wide groups of column indices
            column_groups = [
                list(range(first, first + group_width))
                for first in range(0, total_cols, group_width)
            ]

            # Slice boundaries into 'column_groups', one slice per node
            bounds = [0] + np.cumsum(counts).tolist()

            per_node_inputs = [
                [(inputs, group) for group in column_groups[lo:hi]]
                for lo, hi in zip(bounds[:-1], bounds[1:])
            ]

        # One PermProducts node per entry of 'per_node_inputs'
        perm_nodes = [spn.PermProducts(*node_inputs)
                      for node_inputs in per_node_inputs]

        # Join all PermProducts nodes under a single weighted Sum root
        root = spn.Sum(*perm_nodes)
        root.generate_weights()

        compute_value = root.get_log_value if log else root.get_value
        value_op = compute_value(inference_type=inf_type)

        return spn.initialize_weights(root), value_op
Exemplo n.º 2
0
    def poons_multi(inputs,
                    num_vals,
                    num_mixtures,
                    num_subsets,
                    inf_type,
                    log=False,
                    output=None):
        """Build a POON-like network of multi-op nodes and its MPE path ops.

        Args:
            inputs: Input node, or list of input nodes, feeding the network.
            num_vals: Number of consecutive input columns per subset.
            num_mixtures: Passed as ``num_sums`` to every ParSums node.
            num_subsets: Number of ParSums nodes to create.
            inf_type: Passed as ``value_inference_type`` to ``spn.MPEPath``.
            log: Its truth value is passed as ``log`` to ``spn.MPEPath``.
            output: Not used by this function.

        Returns:
            The tuple ``(root, spn.initialize_weights(root), path_ops)``,
            where ``path_ops`` holds the MPE path counts of each input.
        """
        # One ParSums node per subset, each over its own run of 'num_vals'
        # consecutive input columns
        subsets = []
        for subset_idx in range(num_subsets):
            columns = list(range(subset_idx * num_vals,
                                 (subset_idx + 1) * num_vals))
            subsets.append(
                spn.ParSums((inputs, columns), num_sums=num_mixtures))

        root = spn.Sum(spn.PermProducts(*subsets), name="root")

        # Generate all weights in the network
        spn.generate_weights(root)

        # Build the MPE path generator for the requested inference type
        mpe_path_gen = spn.MPEPath(value_inference_type=inf_type,
                                   log=bool(log))
        mpe_path_gen.get_mpe_path(root)

        # Collect the path counts of every input node
        sources = inputs if isinstance(inputs, list) else [inputs]
        path_ops = [mpe_path_gen.counts[src] for src in sources]
        return root, spn.initialize_weights(root), path_ops
Exemplo n.º 3
0
        def test(counts, inputs, feed, output):
            """Compare a PermProducts log MPE path against ``output``."""
            with self.subTest(counts=counts, inputs=inputs, feed=feed):
                node = spn.PermProducts(*inputs)
                counts_op = tf.identity(counts)
                input_values = [inp[0].get_value() for inp in inputs]
                path_op = node._compute_log_mpe_path(counts_op, *input_values)
                with self.test_session() as sess:
                    results = sess.run(path_op, feed_dict=feed)

                # Compare each returned array with its expected counterpart
                for actual, expected in zip(results, output):
                    expected_arr = np.array(
                        expected, dtype=spn.conf.dtype.as_numpy_dtype())
                    np.testing.assert_array_almost_equal(actual, expected_arr)
Exemplo n.º 4
0
    def poons_multi(inputs, num_vals, num_mixtures, num_subsets, inf_type,
                    log=False, output=None):
        """Build a POON-like network of multi-op nodes and its value op.

        Args:
            inputs: Input node feeding every ParSums node.
            num_vals: Number of consecutive input columns per subset.
            num_mixtures: Passed as ``num_sums`` to every ParSums node.
            num_subsets: Number of ParSums nodes to create.
            inf_type: Inference type forwarded to the root's value method.
            log: When truthy the op comes from ``get_log_value``, otherwise
                from ``get_value``.
            output: Not used by this function.

        Returns:
            The tuple ``(root, spn.initialize_weights(root), value_op)``.
        """
        def subset_columns(subset_idx):
            # Column indices read by the subset at position 'subset_idx'
            return list(range(subset_idx * num_vals,
                              (subset_idx + 1) * num_vals))

        # One ParSums node per subset, joined by a PermProducts node under
        # the root Sum
        subsets = [
            spn.ParSums((inputs, subset_columns(subset_idx)),
                        num_sums=num_mixtures)
            for subset_idx in range(num_subsets)
        ]
        root = spn.Sum(spn.PermProducts(*subsets), name="root")

        # Generate all weights in the network
        spn.generate_weights(root)

        # Pick the value method based on 'log'
        compute_value = root.get_log_value if log else root.get_value
        value_op = compute_value(inference_type=inf_type)

        return root, spn.initialize_weights(root), value_op
Exemplo n.º 5
0
 def test(inputs, feed, output):
     """Check all four value ops of a PermProducts node against ``output``."""
     with self.subTest(inputs=inputs, feed=feed):
         node = spn.PermProducts(*inputs)
         marginal = node.get_value(spn.InferenceType.MARGINAL)
         log_marginal = node.get_log_value(spn.InferenceType.MARGINAL)
         mpe = node.get_value(spn.InferenceType.MPE)
         log_mpe = node.get_log_value(spn.InferenceType.MPE)
         with self.test_session() as sess:
             # Log-space ops are exponentiated so that every result is
             # compared with the same expected values
             results = [
                 sess.run(marginal, feed_dict=feed),
                 sess.run(tf.exp(log_marginal), feed_dict=feed),
                 sess.run(mpe, feed_dict=feed),
                 sess.run(tf.exp(log_mpe), feed_dict=feed),
             ]
         for actual in results:
             np.testing.assert_array_almost_equal(
                 actual,
                 np.array(output, dtype=spn.conf.dtype.as_numpy_dtype()))
Exemplo n.º 6
0
 def test_compute_valid(self):
     """Check is_valid() on valid and invalid PermProducts nodes"""
     iv_pair = spn.IVs(num_vars=2, num_vals=3)
     iv_triple = spn.IVs(num_vars=3, num_vals=3)
     cont_triple = spn.ContVars(num_vars=3)
     cont_pair = spn.ContVars(num_vars=2)

     # Nodes expected to be valid
     valid_nodes = [
         spn.PermProducts((iv_pair, [0, 1]), (iv_pair, [4, 5])),
         spn.PermProducts((iv_pair, [3, 5]), (iv_triple, [0, 1, 2])),
         spn.PermProducts((iv_triple, [0, 1, 2]), (iv_triple, [3, 4, 5]),
                          (iv_triple, [6, 7, 8])),
         spn.PermProducts((iv_triple, [6, 8]), (cont_triple, [0, 1])),
         spn.PermProducts((cont_triple, [1]), cont_pair),
         spn.PermProducts(cont_triple, cont_pair),
         spn.PermProducts((cont_triple, [0, 1, 2])),
         spn.PermProducts((cont_pair, [0]), (cont_pair, [1])),
     ]
     for node in valid_nodes:
         self.assertTrue(node.is_valid())

     # Nodes expected to be invalid
     invalid_nodes = [
         spn.PermProducts((iv_pair, [0, 1]), (iv_pair, [1, 2])),
         spn.PermProducts((iv_pair, [3, 4, 5]), (iv_triple, [0]),
                          (iv_triple, [0, 1, 2])),
         spn.PermProducts((iv_triple, [3, 5]), (cont_triple, [0]),
                          (cont_triple, [0])),
         spn.PermProducts((cont_pair, [1]), (cont_pair, [1])),
         spn.PermProducts(cont_pair, cont_pair),
     ]
     single_product = spn.PermProducts((iv_pair, [0]), (iv_pair, [1]))
     for node in invalid_nodes:
         self.assertFalse(node.is_valid())

     # The last node models a single product and is expected to be invalid
     self.assertEqual(single_product.num_prods, 1)
     self.assertFalse(single_product.is_valid())
Exemplo n.º 7
0
 def test_comput_scope(self):
     """Check get_scope() of PermProducts nodes and the nodes around them.

     Builds a two-layer graph of Sum, Product, Concat and PermProducts nodes
     over one IVs node and one ContVars node, then compares the scope list
     returned by every node with the expected per-output scopes.
     """
     # Create graph
     # Leaf variables
     v12 = spn.IVs(num_vars=2, num_vals=4, name="V12")
     v34 = spn.ContVars(num_vars=2, name="V34")
     # Sum and Product nodes over selected leaf outputs; only S1 gets its
     # own IVs generated
     s1 = spn.Sum((v12, [0, 1, 2, 3]), name="S1")
     s1.generate_ivs()
     s2 = spn.Sum((v12, [4, 5, 6, 7]), name="S2")
     p1 = spn.Product((v12, [0, 7]), name="P1")
     p2 = spn.Product((v12, [3, 4]), name="P2")
     p3 = spn.Product(v34, name="P3")
     # Concat nodes grouping the nodes above
     n1 = spn.Concat(s1, s2, p3, name="N1")
     n2 = spn.Concat(p1, p2, name="N2")
     # First layer of PermProducts nodes
     pp1 = spn.PermProducts(n1, n2, name="PP1")  # num_prods = 6
     pp2 = spn.PermProducts((n1, [0, 1]), (n2, [0]),
                            name="PP2")  # num_prods = 2
     pp3 = spn.PermProducts(n2, p3, name="PP3")  # num_prods = 2
     pp4 = spn.PermProducts(p2, p3, name="PP4")  # num_prods = 1
     pp5 = spn.PermProducts((n2, [0, 1]), name="PP5")  # num_prods = 1
     pp6 = spn.PermProducts(p3, name="PP6")  # num_prods = 1
     # Nodes combining outputs of the first-layer PermProducts; S3 also gets
     # its own IVs generated
     n3 = spn.Concat((pp1, [0, 2, 3]), pp2, pp4, name="N3")
     s3 = spn.Sum((pp1, [0, 2, 4]), (pp1, [1, 3, 5]),
                  pp2,
                  pp3, (pp4, 0),
                  pp5,
                  pp6,
                  name="S3")
     s3.generate_ivs()
     n4 = spn.Concat((pp3, [0, 1]), pp5, (pp6, 0), name="N4")
     # Second layer of PermProducts nodes
     pp7 = spn.PermProducts(n3, s3, n4, name="PP7")  # num_prods = 24
     pp8 = spn.PermProducts(n3, name="PP8")  # num_prods = 1
     pp9 = spn.PermProducts((n4, [0, 1, 2, 3]), name="PP9")  # num_prods = 1
     # Test
     # Leaves: V12 is expected to yield four entries per variable, V34 one
     # entry per variable
     self.assertListEqual(v12.get_scope(), [
         spn.Scope(v12, 0),
         spn.Scope(v12, 0),
         spn.Scope(v12, 0),
         spn.Scope(v12, 0),
         spn.Scope(v12, 1),
         spn.Scope(v12, 1),
         spn.Scope(v12, 1),
         spn.Scope(v12, 1)
     ])
     self.assertListEqual(
         v34.get_scope(),
         [spn.Scope(v34, 0), spn.Scope(v34, 1)])
     # S1's expected scope includes its generated IVs node; S2 has no IVs
     self.assertListEqual(s1.get_scope(),
                          [spn.Scope(v12, 0) | spn.Scope(s1.ivs.node, 0)])
     self.assertListEqual(s2.get_scope(), [spn.Scope(v12, 1)])
     # Each Product is expected to have one scope: the union of its inputs
     self.assertListEqual(p1.get_scope(),
                          [spn.Scope(v12, 0) | spn.Scope(v12, 1)])
     self.assertListEqual(p2.get_scope(),
                          [spn.Scope(v12, 0) | spn.Scope(v12, 1)])
     self.assertListEqual(p3.get_scope(),
                          [spn.Scope(v34, 0) | spn.Scope(v34, 1)])
     # Concat nodes are expected to list their inputs' scopes in order
     self.assertListEqual(n1.get_scope(), [
         spn.Scope(v12, 0) | spn.Scope(s1.ivs.node, 0),
         spn.Scope(v12, 1),
         spn.Scope(v34, 0) | spn.Scope(v34, 1)
     ])
     self.assertListEqual(n2.get_scope(), [
         spn.Scope(v12, 0) | spn.Scope(v12, 1),
         spn.Scope(v12, 0) | spn.Scope(v12, 1)
     ])
     # First-layer PermProducts: one expected scope per modelled product
     self.assertListEqual(pp1.get_scope(), [
         spn.Scope(v12, 0) | spn.Scope(v12, 1) | spn.Scope(s1.ivs.node, 0),
         spn.Scope(v12, 0) | spn.Scope(v12, 1) | spn.Scope(s1.ivs.node, 0),
         spn.Scope(v12, 0) | spn.Scope(v12, 1),
         spn.Scope(v12, 0) | spn.Scope(v12, 1),
         spn.Scope(v12, 0) | spn.Scope(v12, 1) | spn.Scope(v34, 0)
         | spn.Scope(v34, 1),
         spn.Scope(v12, 0) | spn.Scope(v12, 1) | spn.Scope(v34, 0)
         | spn.Scope(v34, 1)
     ])
     self.assertListEqual(pp2.get_scope(), [
         spn.Scope(v12, 0) | spn.Scope(v12, 1) | spn.Scope(s1.ivs.node, 0),
         spn.Scope(v12, 0) | spn.Scope(v12, 1)
     ])
     self.assertListEqual(pp3.get_scope(), [
         spn.Scope(v12, 0) | spn.Scope(v12, 1) | spn.Scope(v34, 0)
         | spn.Scope(v34, 1),
         spn.Scope(v12, 0) | spn.Scope(v12, 1) | spn.Scope(v34, 0)
         | spn.Scope(v34, 1)
     ])
     self.assertListEqual(pp4.get_scope(), [
         spn.Scope(v12, 0) | spn.Scope(v12, 1) | spn.Scope(v34, 0)
         | spn.Scope(v34, 1)
     ])
     self.assertListEqual(pp5.get_scope(),
                          [spn.Scope(v12, 0) | spn.Scope(v12, 1)])
     self.assertListEqual(pp6.get_scope(),
                          [spn.Scope(v34, 0) | spn.Scope(v34, 1)])
     # Nodes built on top of the first-layer PermProducts
     self.assertListEqual(n3.get_scope(), [
         spn.Scope(v12, 0) | spn.Scope(v12, 1) | spn.Scope(s1.ivs.node, 0),
         spn.Scope(v12, 0) | spn.Scope(v12, 1),
         spn.Scope(v12, 0) | spn.Scope(v12, 1),
         spn.Scope(v12, 0) | spn.Scope(v12, 1) | spn.Scope(s1.ivs.node, 0),
         spn.Scope(v12, 0) | spn.Scope(v12, 1),
         spn.Scope(v12, 0) | spn.Scope(v12, 1) | spn.Scope(v34, 0)
         | spn.Scope(v34, 1)
     ])
     # S3's expected scope covers all of its inputs plus its own IVs node
     self.assertListEqual(s3.get_scope(), [
         spn.Scope(v12, 0) | spn.Scope(v12, 1) | spn.Scope(v34, 0)
         | spn.Scope(v34, 1) | spn.Scope(s1.ivs.node, 0)
         | spn.Scope(s3.ivs.node, 0)
     ])
     self.assertListEqual(n4.get_scope(), [
         spn.Scope(v12, 0) | spn.Scope(v12, 1) | spn.Scope(v34, 0)
         | spn.Scope(v34, 1),
         spn.Scope(v12, 0) | spn.Scope(v12, 1) | spn.Scope(v34, 0)
         | spn.Scope(v34, 1),
         spn.Scope(v12, 0) | spn.Scope(v12, 1),
         spn.Scope(v34, 0) | spn.Scope(v34, 1)
     ])
     # Second-layer PermProducts: all 24 products of PP7 are expected to
     # share the same scope
     self.assertListEqual(pp7.get_scope(), [
         spn.Scope(v12, 0) | spn.Scope(v12, 1) | spn.Scope(v34, 0)
         | spn.Scope(v34, 1) | spn.Scope(s1.ivs.node, 0)
         | spn.Scope(s3.ivs.node, 0)
     ] * 24)
     self.assertListEqual(pp8.get_scope(), [
         spn.Scope(v12, 0) | spn.Scope(v12, 1) | spn.Scope(s1.ivs.node, 0)
         | spn.Scope(v34, 0) | spn.Scope(v34, 1)
     ])
     self.assertListEqual(pp9.get_scope(), [
         spn.Scope(v12, 0) | spn.Scope(v12, 1) | spn.Scope(v34, 0)
         | spn.Scope(v34, 1)
     ])