Ejemplo n.º 1
0
    def test_inference_results(self):
        """Check that the TF RAT-SPN and its simple-SPN export agree.

        A RAT-SPN with normalized sums is built on a randomly split region
        graph over ``num_dims`` variables and evaluated with TensorFlow on a
        random input batch. The same model is then exported with
        ``spn.get_simple_spn(sess)`` and each exported node is evaluated with
        ``Inference.likelihood``. The TensorFlow output is exponentiated before
        comparison (presumably it holds log-likelihoods -- TODO confirm), and
        every element must agree to within 1% relative error.
        """
        np.random.seed(123)

        num_dims = 20
        batch_size = 10

        # Build in a private graph so the test does not depend on, or pollute,
        # the process-global default graph (e.g. a second "obj-spn" scope).
        with tf.Graph().as_default():
            # The graph-level seed applies to the graph that is default when it
            # is set, so it must be set inside this context.
            tf.set_random_seed(123)

            rg = region_graph.RegionGraph(range(num_dims))
            for _ in range(10):
                rg.random_split(2, 3)

            args = RAT_SPN.SpnArgs()
            args.normalized_sums = True
            spn = RAT_SPN.RatSpn(10, region_graph=rg, name="obj-spn", args=args)

            dummy_input = np.random.normal(0.0, 1.2, [batch_size, num_dims])
            input_ph = tf.placeholder(tf.float32, [batch_size, num_dims])
            output_tensor = spn.forward(input_ph)

            # The context manager guarantees the session is closed, even if an
            # assertion or TF error fires. The initializer runs after the whole
            # graph is built so any variables created along the way are covered.
            with tf.Session() as sess:
                sess.run(tf.global_variables_initializer())
                tf_output = sess.run(
                    output_tensor, feed_dict={input_ph: dummy_input})
                output_nodes = spn.get_simple_spn(sess)

        simple_output = np.stack(
            [Inference.likelihood(node, dummy_input) for node in output_nodes])
        # NOTE(review): batch_size and the first RatSpn argument are both 10.
        # A transposed or extra-axis simple_output would therefore broadcast
        # against tf_output instead of raising. Confirm the two arrays share
        # the same orientation.
        rel_error = np.abs(simple_output / np.exp(tf_output) - 1.0)

        # Same pass/fail semantics as np.all(rel_error < 1e-2), NaN included,
        # but the failure message reports the worst-case error.
        self.assertLess(np.max(rel_error), 1e-2)
Ejemplo n.º 2
0
            num_correct_batch = np.sum(max_idx == label_batch)
            num_correct += num_correct_batch

        acc = num_correct / (batch_size * batches_per_epoch)
        print(i, acc, cur_loss)


def softmax(x, axis=0):
    """Return the softmax of the scores in ``x`` taken along ``axis``.

    The per-slice maximum is subtracted before exponentiating. This keeps
    large scores from overflowing ``np.exp`` and leaves the result unchanged,
    because softmax is invariant to adding a constant to every score.
    """
    peak = np.max(x, axis=axis, keepdims=True)
    weights = np.exp(x - peak)
    normalizer = np.sum(weights, axis=axis, keepdims=True)
    return weights / normalizer


if __name__ == "__main__":
    # Build a region graph over 28 * 28 = 784 variables (presumably one per
    # pixel of a 28x28 MNIST image, matching load_mnist() below -- TODO confirm).
    rg = region_graph.RegionGraph(range(28 * 28))
    # Alternative tiny 3x3 region graph, left commented out (presumably for
    # quick debugging runs).
    # rg = region_graph.RegionGraph(range(3 * 3))
    # Two rounds of random splitting; the meaning of the (2, 1) arguments is
    # defined by RegionGraph.random_split.
    for _ in range(0, 2):
        rg.random_split(2, 1)

    # SPN hyperparameters. Field semantics live in RAT_SPN.SpnArgs; presumably
    # num_sums / num_gauss control the number of sum nodes and Gaussian leaves
    # per region -- TODO confirm.
    args = RAT_SPN.SpnArgs()
    args.normalized_sums = True
    args.num_sums = 2
    args.num_gauss = 2
    # First argument 10: presumably the number of output classes (MNIST
    # digits) -- confirm against the RatSpn signature.
    spn = RAT_SPN.RatSpn(10, region_graph=rg, name="obj-spn", args=args)
    print("num_params", spn.num_params())

    # Initialize all global TF variables that exist at this point.
    # NOTE(review): sess is not closed anywhere in the visible code.
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    # Load MNIST; the second returned element (presumably the test split) is
    # discarded.
    (train_im, train_labels), _ = load_mnist()