NOTE: Outgoing edges from observed nodes are removed. 

We will assume one observation at node x4. The program generates samples and
then illustrates how to compute the needed kernel matrices and pre-computed
solutions to the linear systems to be solved by Kernel-BP. All matrices are
stored in separate files. In addition, a graph definition file is created, which
can be loaded by the provided line reader of our GraphLab Kernel-BP
implementation. See documentation for details.
"""

from numpy.random import randint
from src.GaussianKernel import GaussianKernel
from src.PrecomputeDenseMatrixKernelBP import PrecomputeDenseMatrixKernelBP
from src.ToyModel import ToyModel

model=ToyModel()
graph=model.get_moralised_graph()

# one observation at node 4
observations={4:0.0}

# directed edges for kernel BP implementation
edges=model.extract_edges(observations)

print "graph:", graph
print "observations:", observations
print "edges:", edges

# sample data, random number of samples for each edge
n_min=5
n_max=6
    def fixed_data(self):
        """
        Uses some fixed data from the ToyModel and pre-computes all matrices.
        Results are asserted to be correct (against Matlab implementation) and
        output files can be used to test the KernelBP implementation.
        """
        model = ToyModel()
        graph = model.get_moralised_graph()

        # one observation at node 4
        observations = {4: 0.0}

        # directed edges for kernel BP implementation
        edges = model.extract_edges(observations)

        print "graph:", graph
        print "observations:", observations
        print "edges:", edges

        # we sample the data jointly, so edges will share data along vertices
        joint_data = {}
        joint_data[1] = [-0.274722354853981, 0.044011207316815, 0.073737451640458]
        joint_data[2] = [-0.173264814517908, 0.213918664844409, 0.123246012188621]
        joint_data[3] = [-0.348879413536605, -0.081766464397055, -0.117171083361484]
        joint_data[4] = [-0.014012058355118, -0.145789276405117, -0.317649695308685]
        joint_data[5] = [-0.291794859908481, 0.260902212951398, -0.276258182225143]

        # generate data in format that works for the dense matrix class, i.e., a pair of
        # points for every edge
        data = {}
        for edge in edges:
            # only sample once per undirected edge
            inverse_edge = (edge[1], edge[0])
            data[edge] = (joint_data[edge[0]], joint_data[edge[1]])
            data[inverse_edge] = (joint_data[edge[1]], joint_data[edge[0]])

        # Gaussian kernel used in matlab files
        kernel = GaussianKernel(sigma=sqrt(0.15))

        # use the example class for dense matrix data that can be stored in memory

        precomputer = PrecomputeDenseMatrixKernelBP(
            graph, edges, data, observations, kernel, reg_lambda=0.1, output_filename=self.output_filename
        )

        precomputer.precompute()

        # go through all the files and make sure they contain the correct matrices

        # files created by matlab implementation
        filenames = [
            "1->2->3_non_obs_kernel.txt",
            "1->2->4_non_obs_kernel.txt",
            "1->3->2_non_obs_kernel.txt",
            "1->3->4_non_obs_kernel.txt",
            "1->3->5_non_obs_kernel.txt",
            "2->1->3_non_obs_kernel.txt",
            "2->3->1_non_obs_kernel.txt",
            "2->3->4_non_obs_kernel.txt",
            "2->3->5_non_obs_kernel.txt",
            "2->4->3_non_obs_kernel.txt",
            "3->1->2_non_obs_kernel.txt",
            "3->2->1_non_obs_kernel.txt",
            "3->2->4_non_obs_kernel.txt",
            "3->4->2_non_obs_kernel.txt",
            "4->2->1_non_obs_kernel.txt",
            "4->2->3_non_obs_kernel.txt",
            "4->3->1_non_obs_kernel.txt",
            "4->3->2_non_obs_kernel.txt",
            "4->3->5_non_obs_kernel.txt",
            "5->3->1_non_obs_kernel.txt",
            "5->3->2_non_obs_kernel.txt",
            "5->3->4_non_obs_kernel.txt",
            "3->4_obs_kernel.txt",
            "2->4_obs_kernel.txt",
            "1->2_L_s.txt",
            "1->3_L_s.txt",
            "2->1_L_s.txt",
            "2->3_L_s.txt",
            "2->4_L_s.txt",
            "2->4_L_t.txt",
            "3->1_L_s.txt",
            "3->2_L_s.txt",
            "3->4_L_s.txt",
            "3->4_L_t.txt",
            "3->5_L_s.txt",
            "5->3_L_s.txt",
        ]

        # from matlab implementation
        matrices = {
            "2->1->3_non_obs_kernel.txt": asarray(
                [[1.000000, 0.712741, 0.667145], [0.712741, 1.000000, 0.997059], [0.667145, 0.997059, 1.000000]]
            ),
            "3->1->2_non_obs_kernel.txt": asarray(
                [[1.000000, 0.712741, 0.667145], [0.712741, 1.000000, 0.997059], [0.667145, 0.997059, 1.000000]]
            ),
            "1->2->3_non_obs_kernel.txt": asarray(
                [[1.000000, 0.606711, 0.745976], [0.606711, 1.000000, 0.972967], [0.745976, 0.972967, 1.000000]]
            ),
            "1->2->4_non_obs_kernel.txt": asarray(
                [[1.000000, 0.606711, 0.745976], [0.606711, 1.000000, 0.972967], [0.745976, 0.972967, 1.000000]]
            ),
            "3->2->1_non_obs_kernel.txt": asarray(
                [[1.000000, 0.606711, 0.745976], [0.606711, 1.000000, 0.972967], [0.745976, 0.972967, 1.000000]]
            ),
            "3->2->4_non_obs_kernel.txt": asarray(
                [[1.000000, 0.606711, 0.745976], [0.606711, 1.000000, 0.972967], [0.745976, 0.972967, 1.000000]]
            ),
            "4->2->1_non_obs_kernel.txt": asarray(
                [[1.000000, 0.606711, 0.745976], [0.606711, 1.000000, 0.972967], [0.745976, 0.972967, 1.000000]]
            ),
            "4->2->3_non_obs_kernel.txt": asarray(
                [[1.000000, 0.606711, 0.745976], [0.606711, 1.000000, 0.972967], [0.745976, 0.972967, 1.000000]]
            ),
            "1->3->2_non_obs_kernel.txt": asarray(
                [[1.000000, 0.788336, 0.836137], [0.788336, 1.000000, 0.995830], [0.836137, 0.995830, 1.000000]]
            ),
            "1->3->4_non_obs_kernel.txt": asarray(
                [[1.000000, 0.788336, 0.836137], [0.788336, 1.000000, 0.995830], [0.836137, 0.995830, 1.000000]]
            ),
            "1->3->5_non_obs_kernel.txt": asarray(
                [[1.000000, 0.788336, 0.836137], [0.788336, 1.000000, 0.995830], [0.836137, 0.995830, 1.000000]]
            ),
            "2->3->1_non_obs_kernel.txt": asarray(
                [[1.000000, 0.788336, 0.836137], [0.788336, 1.000000, 0.995830], [0.836137, 0.995830, 1.000000]]
            ),
            "2->3->4_non_obs_kernel.txt": asarray(
                [[1.000000, 0.788336, 0.836137], [0.788336, 1.000000, 0.995830], [0.836137, 0.995830, 1.000000]]
            ),
            "2->3->5_non_obs_kernel.txt": asarray(
                [[1.000000, 0.788336, 0.836137], [0.788336, 1.000000, 0.995830], [0.836137, 0.995830, 1.000000]]
            ),
            "4->3->1_non_obs_kernel.txt": asarray(
                [[1.000000, 0.788336, 0.836137], [0.788336, 1.000000, 0.995830], [0.836137, 0.995830, 1.000000]]
            ),
            "4->3->2_non_obs_kernel.txt": asarray(
                [[1.000000, 0.788336, 0.836137], [0.788336, 1.000000, 0.995830], [0.836137, 0.995830, 1.000000]]
            ),
            "4->3->5_non_obs_kernel.txt": asarray(
                [[1.000000, 0.788336, 0.836137], [0.788336, 1.000000, 0.995830], [0.836137, 0.995830, 1.000000]]
            ),
            "5->3->1_non_obs_kernel.txt": asarray(
                [[1.000000, 0.788336, 0.836137], [0.788336, 1.000000, 0.995830], [0.836137, 0.995830, 1.000000]]
            ),
            "5->3->2_non_obs_kernel.txt": asarray(
                [[1.000000, 0.788336, 0.836137], [0.788336, 1.000000, 0.995830], [0.836137, 0.995830, 1.000000]]
            ),
            "5->3->4_non_obs_kernel.txt": asarray(
                [[1.000000, 0.788336, 0.836137], [0.788336, 1.000000, 0.995830], [0.836137, 0.995830, 1.000000]]
            ),
            "2->4->3_non_obs_kernel.txt": asarray(
                [[1.000000, 0.943759, 0.735416], [0.943759, 1.000000, 0.906238], [0.735416, 0.906238, 1.000000]]
            ),
            "3->4->2_non_obs_kernel.txt": asarray(
                [[1.000000, 0.943759, 0.735416], [0.943759, 1.000000, 0.906238], [0.735416, 0.906238, 1.000000]]
            ),
            "2->4_obs_kernel.txt": asarray([[0.999346], [0.931603], [0.714382]]),
            "2->4_L_s.txt": asarray(
                [[1.048809, 0.000000, 0.000000], [0.578476, 0.874852, 0.000000], [0.711260, 0.641846, 0.426782]]
            ),
            "2->4_L_t.txt": asarray(
                [[1.048809, 0.000000, 0.000000], [0.899839, 0.538785, 0.000000], [0.701191, 0.510924, 0.589311]]
            ),
            "3->4_obs_kernel.txt": asarray([[0.999346], [0.931603], [0.714382]]),
            "3->4_L_s.txt": asarray(
                [[1.048809, 0.000000, 0.000000], [0.751649, 0.731453, 0.000000], [0.797226, 0.542204, 0.412852]]
            ),
            "3->4_L_t.txt": asarray(
                [[1.048809, 0.000000, 0.000000], [0.899839, 0.538785, 0.000000], [0.701191, 0.510924, 0.589311]]
            ),
            "2->1_L_s.txt": asarray(
                [[1.048809, 0.000000, 0.000000], [0.578476, 0.874852, 0.000000], [0.711260, 0.641846, 0.426782]]
            ),
            "3->1_L_s.txt": asarray(
                [[1.048809, 0.000000, 0.000000], [0.751649, 0.731453, 0.000000], [0.797226, 0.542204, 0.412852]]
            ),
            "1->2_L_s.txt": asarray(
                [[1.048809, 0.000000, 0.000000], [0.679572, 0.798863, 0.000000], [0.636098, 0.706985, 0.442211]]
            ),
            "3->2_L_s.txt": asarray(
                [[1.048809, 0.000000, 0.000000], [0.751649, 0.731453, 0.000000], [0.797226, 0.542204, 0.412852]]
            ),
            "1->3_L_s.txt": asarray(
                [[1.048809, 0.000000, 0.000000], [0.679572, 0.798863, 0.000000], [0.636098, 0.706985, 0.442211]]
            ),
            "2->3_L_s.txt": asarray(
                [[1.048809, 0.000000, 0.000000], [0.578476, 0.874852, 0.000000], [0.711260, 0.641846, 0.426782]]
            ),
            "5->3_L_s.txt": asarray(
                [[1.048809, 0.000000, 0.000000], [0.344417, 0.990645, 0.000000], [0.952696, 0.054589, 0.435191]]
            ),
            "3->5_L_s.txt": asarray(
                [[1.048809, 0.000000, 0.000000], [0.751649, 0.731453, 0.000000], [0.797226, 0.542204, 0.412852]]
            ),
        }

        assert len(filenames) == len(matrices)

        for filename in filenames:
            self.assertTrue(self.assert_file_matrix(self.output_folder + filename, matrices[filename]))