We assume one observation at node x4. The program generates samples and then
illustrates how to compute the kernel matrices and the pre-computed solutions
of the linear systems that Kernel-BP needs to solve. All matrices are stored in
separate files. In addition, a graph definition file is created, which can be
loaded by the line reader provided with our GraphLab Kernel-BP implementation.
See the documentation for details.
"""

from numpy.random import randint
from src.GaussianKernel import GaussianKernel
from src.PrecomputeDenseMatrixKernelBP import PrecomputeDenseMatrixKernelBP
from src.ToyModel import ToyModel

# Script section: build the toy model, fix the observation and derive the
# directed edges used by Kernel-BP. This section continues beyond the visible
# chunk; the names bound here are module-level and used by later code.
model=ToyModel()
# moralised graph of the toy model (whatever ToyModel.get_moralised_graph returns)
graph=model.get_moralised_graph()

# one observation at node 4
observations={4:0.0}

# directed edges for kernel BP implementation
edges=model.extract_edges(observations)

# show the model structure (Python 2 print statements)
print "graph:", graph
print "observations:", observations
print "edges:", edges

# sample data, random number of samples for each edge
# NOTE(review): n_min/n_max presumably bound the per-edge sample count (e.g.
# via the imported randint); the code that uses them is not visible here.
n_min=5
n_max=6
# NOTE(review): presumably filled per directed edge by the code that follows;
# confirm against that code.
data={} 
    def fixed_data(self):
        """
        Uses fixed data from the ToyModel and pre-computes all matrices.

        The sample values and the expected matrices come from a Matlab
        reference implementation. Every file written by the precomputer is
        compared against its reference matrix via ``self.assert_file_matrix``;
        the output files can also be used to test the KernelBP implementation.

        Relies on ``self.output_filename`` (passed to the precomputer) and
        ``self.output_folder`` (prefix of every file that is checked).
        """
        # Imported locally so this method does not depend on module-level
        # numpy imports.
        from numpy import asarray, sqrt

        model = ToyModel()
        graph = model.get_moralised_graph()

        # one observation at node 4
        observations = {4: 0.0}

        # directed edges for kernel BP implementation
        edges = model.extract_edges(observations)

        # Single-argument, %-formatted prints produce the same output under
        # Python 2 and Python 3.
        print("graph: %s" % (graph,))
        print("observations: %s" % (observations,))
        print("edges: %s" % (edges,))

        # The data is sampled jointly, so every edge touching a vertex reuses
        # that vertex's samples (values fixed from the Matlab reference run).
        joint_data = {
            1: [-0.274722354853981, 0.044011207316815, 0.073737451640458],
            2: [-0.173264814517908, 0.213918664844409, 0.123246012188621],
            3: [-0.348879413536605, -0.081766464397055, -0.117171083361484],
            4: [-0.014012058355118, -0.145789276405117, -0.317649695308685],
            5: [-0.291794859908481, 0.260902212951398, -0.276258182225143],
        }

        # The dense matrix class expects a pair of sample lists for every
        # directed edge; both orientations of each edge are filled in here.
        data = {}
        for edge in edges:
            source, target = edge[0], edge[1]
            data[edge] = (joint_data[source], joint_data[target])
            data[(target, source)] = (joint_data[target], joint_data[source])

        # Gaussian kernel used in matlab files
        kernel = GaussianKernel(sigma=sqrt(0.15))

        # use the example class for dense matrix data that can be stored in memory
        precomputer = PrecomputeDenseMatrixKernelBP(
            graph, edges, data, observations, kernel, reg_lambda=0.1, output_filename=self.output_filename
        )

        precomputer.precompute()

        # go through all the files and make sure they contain the correct matrices

        # files created by matlab implementation
        filenames = [
            "1->2->3_non_obs_kernel.txt",
            "1->2->4_non_obs_kernel.txt",
            "1->3->2_non_obs_kernel.txt",
            "1->3->4_non_obs_kernel.txt",
            "1->3->5_non_obs_kernel.txt",
            "2->1->3_non_obs_kernel.txt",
            "2->3->1_non_obs_kernel.txt",
            "2->3->4_non_obs_kernel.txt",
            "2->3->5_non_obs_kernel.txt",
            "2->4->3_non_obs_kernel.txt",
            "3->1->2_non_obs_kernel.txt",
            "3->2->1_non_obs_kernel.txt",
            "3->2->4_non_obs_kernel.txt",
            "3->4->2_non_obs_kernel.txt",
            "4->2->1_non_obs_kernel.txt",
            "4->2->3_non_obs_kernel.txt",
            "4->3->1_non_obs_kernel.txt",
            "4->3->2_non_obs_kernel.txt",
            "4->3->5_non_obs_kernel.txt",
            "5->3->1_non_obs_kernel.txt",
            "5->3->2_non_obs_kernel.txt",
            "5->3->4_non_obs_kernel.txt",
            "3->4_obs_kernel.txt",
            "2->4_obs_kernel.txt",
            "1->2_L_s.txt",
            "1->3_L_s.txt",
            "2->1_L_s.txt",
            "2->3_L_s.txt",
            "2->4_L_s.txt",
            "2->4_L_t.txt",
            "3->1_L_s.txt",
            "3->2_L_s.txt",
            "3->4_L_s.txt",
            "3->4_L_t.txt",
            "3->5_L_s.txt",
            "5->3_L_s.txt",
        ]

        # Reference matrices from the Matlab implementation. Many files share
        # identical values, so each distinct matrix is written out once. In
        # this data set the non-observed kernels repeat per middle node, the
        # L_s factors per source node and the L_t factors per target node.
        non_obs_mid_1 = asarray(
            [[1.000000, 0.712741, 0.667145], [0.712741, 1.000000, 0.997059], [0.667145, 0.997059, 1.000000]]
        )
        non_obs_mid_2 = asarray(
            [[1.000000, 0.606711, 0.745976], [0.606711, 1.000000, 0.972967], [0.745976, 0.972967, 1.000000]]
        )
        non_obs_mid_3 = asarray(
            [[1.000000, 0.788336, 0.836137], [0.788336, 1.000000, 0.995830], [0.836137, 0.995830, 1.000000]]
        )
        non_obs_mid_4 = asarray(
            [[1.000000, 0.943759, 0.735416], [0.943759, 1.000000, 0.906238], [0.735416, 0.906238, 1.000000]]
        )
        obs_kernel_to_4 = asarray([[0.999346], [0.931603], [0.714382]])
        l_s_from_1 = asarray(
            [[1.048809, 0.000000, 0.000000], [0.679572, 0.798863, 0.000000], [0.636098, 0.706985, 0.442211]]
        )
        l_s_from_2 = asarray(
            [[1.048809, 0.000000, 0.000000], [0.578476, 0.874852, 0.000000], [0.711260, 0.641846, 0.426782]]
        )
        l_s_from_3 = asarray(
            [[1.048809, 0.000000, 0.000000], [0.751649, 0.731453, 0.000000], [0.797226, 0.542204, 0.412852]]
        )
        l_s_from_5 = asarray(
            [[1.048809, 0.000000, 0.000000], [0.344417, 0.990645, 0.000000], [0.952696, 0.054589, 0.435191]]
        )
        l_t_to_4 = asarray(
            [[1.048809, 0.000000, 0.000000], [0.899839, 0.538785, 0.000000], [0.701191, 0.510924, 0.589311]]
        )

        matrices = {
            "2->1->3_non_obs_kernel.txt": non_obs_mid_1,
            "3->1->2_non_obs_kernel.txt": non_obs_mid_1,
            "1->2->3_non_obs_kernel.txt": non_obs_mid_2,
            "1->2->4_non_obs_kernel.txt": non_obs_mid_2,
            "3->2->1_non_obs_kernel.txt": non_obs_mid_2,
            "3->2->4_non_obs_kernel.txt": non_obs_mid_2,
            "4->2->1_non_obs_kernel.txt": non_obs_mid_2,
            "4->2->3_non_obs_kernel.txt": non_obs_mid_2,
            "1->3->2_non_obs_kernel.txt": non_obs_mid_3,
            "1->3->4_non_obs_kernel.txt": non_obs_mid_3,
            "1->3->5_non_obs_kernel.txt": non_obs_mid_3,
            "2->3->1_non_obs_kernel.txt": non_obs_mid_3,
            "2->3->4_non_obs_kernel.txt": non_obs_mid_3,
            "2->3->5_non_obs_kernel.txt": non_obs_mid_3,
            "4->3->1_non_obs_kernel.txt": non_obs_mid_3,
            "4->3->2_non_obs_kernel.txt": non_obs_mid_3,
            "4->3->5_non_obs_kernel.txt": non_obs_mid_3,
            "5->3->1_non_obs_kernel.txt": non_obs_mid_3,
            "5->3->2_non_obs_kernel.txt": non_obs_mid_3,
            "5->3->4_non_obs_kernel.txt": non_obs_mid_3,
            "2->4->3_non_obs_kernel.txt": non_obs_mid_4,
            "3->4->2_non_obs_kernel.txt": non_obs_mid_4,
            "2->4_obs_kernel.txt": obs_kernel_to_4,
            "3->4_obs_kernel.txt": obs_kernel_to_4,
            "1->2_L_s.txt": l_s_from_1,
            "1->3_L_s.txt": l_s_from_1,
            "2->1_L_s.txt": l_s_from_2,
            "2->3_L_s.txt": l_s_from_2,
            "2->4_L_s.txt": l_s_from_2,
            "3->1_L_s.txt": l_s_from_3,
            "3->2_L_s.txt": l_s_from_3,
            "3->4_L_s.txt": l_s_from_3,
            "3->5_L_s.txt": l_s_from_3,
            "5->3_L_s.txt": l_s_from_5,
            "2->4_L_t.txt": l_t_to_4,
            "3->4_L_t.txt": l_t_to_4,
        }

        # Every expected file must have exactly one reference matrix (stronger
        # than comparing counts, and reported by the TestCase machinery).
        self.assertEqual(sorted(filenames), sorted(matrices))

        for filename in filenames:
            self.assertTrue(self.assert_file_matrix(self.output_folder + filename, matrices[filename]))