# directed edges for kernel BP implementation
edges = model.extract_edges(observations)

# single-argument print(...) produces the same output under Python 2 (print
# statement) and Python 3 (print function)
print("graph: %s" % (graph,))
print("observations: %s" % (observations,))
print("edges: %s" % (edges,))

# sample data, random number of samples for each undirected edge
# NOTE(review): if randint is numpy.random.randint the upper bound is
# exclusive, so this always draws n_min samples; random.randint is
# inclusive -- confirm which randint is imported here
n_min = 5
n_max = 6
data = {}
for edge in edges:
    inverse_edge = (edge[1], edge[0])

    # only sample once per undirected edge: the reverse direction is filled
    # in below with the same pair of sample arrays, node roles swapped
    if edge in data or inverse_edge in data:
        continue

    # sample after the duplicate check so no draws are wasted on the
    # reverse direction of an edge that already has data
    samples = model.sample_real(randint(n_min, n_max))
    data1 = samples[edge[0]]
    data2 = samples[edge[1]]

    data[edge] = (data1, data2)
    data[inverse_edge] = (data2, data1)

# compute all (here Gaussian) kernels of node data at edges with themselves
kernel = GaussianKernel(sigma=1)

# use the example class for dense matrix data that can be stored in memory
precomputer = PrecomputeDenseMatrixKernelBP(
    graph, edges, data, observations, kernel, reg_lambda=0.1, output_filename="graph/graph.txt"
)

precomputer.precompute()
    def fixed_data(self):
        """
        Uses some fixed data from the ToyModel and pre-computes all matrices.
        Results are asserted to be correct (against Matlab implementation) and
        output files can be used to test the KernelBP implementation.

        Reads ``self.output_filename`` (passed to the precomputer as its output
        file) and ``self.output_folder`` (prefix of every checked file name),
        and uses ``self.assert_file_matrix`` to compare each written file with
        its Matlab reference matrix.
        """
        model = ToyModel()
        graph = model.get_moralised_graph()

        # one observation at node 4
        observations = {4: 0.0}

        # directed edges for kernel BP implementation
        edges = model.extract_edges(observations)

        # single-argument print(...) gives the same output under Python 2 and 3
        print("graph: %s" % (graph,))
        print("observations: %s" % (observations,))
        print("edges: %s" % (edges,))

        # we sample the data jointly, so edges will share data along vertices
        joint_data = {
            1: [-0.274722354853981, 0.044011207316815, 0.073737451640458],
            2: [-0.173264814517908, 0.213918664844409, 0.123246012188621],
            3: [-0.348879413536605, -0.081766464397055, -0.117171083361484],
            4: [-0.014012058355118, -0.145789276405117, -0.317649695308685],
            5: [-0.291794859908481, 0.260902212951398, -0.276258182225143],
        }

        # generate data in format that works for the dense matrix class, i.e., a pair of
        # points for every edge; both orientations are stored so either direction
        # of an edge can be looked up (no sampling happens here, data is fixed)
        data = {}
        for edge in edges:
            inverse_edge = (edge[1], edge[0])
            data[edge] = (joint_data[edge[0]], joint_data[edge[1]])
            data[inverse_edge] = (joint_data[edge[1]], joint_data[edge[0]])

        # Gaussian kernel used in matlab files
        kernel = GaussianKernel(sigma=sqrt(0.15))

        # use the example class for dense matrix data that can be stored in memory
        precomputer = PrecomputeDenseMatrixKernelBP(
            graph, edges, data, observations, kernel, reg_lambda=0.1, output_filename=self.output_filename
        )

        precomputer.precompute()

        # go through all the files and make sure they contain the correct matrices

        # files created by matlab implementation
        filenames = [
            "1->2->3_non_obs_kernel.txt",
            "1->2->4_non_obs_kernel.txt",
            "1->3->2_non_obs_kernel.txt",
            "1->3->4_non_obs_kernel.txt",
            "1->3->5_non_obs_kernel.txt",
            "2->1->3_non_obs_kernel.txt",
            "2->3->1_non_obs_kernel.txt",
            "2->3->4_non_obs_kernel.txt",
            "2->3->5_non_obs_kernel.txt",
            "2->4->3_non_obs_kernel.txt",
            "3->1->2_non_obs_kernel.txt",
            "3->2->1_non_obs_kernel.txt",
            "3->2->4_non_obs_kernel.txt",
            "3->4->2_non_obs_kernel.txt",
            "4->2->1_non_obs_kernel.txt",
            "4->2->3_non_obs_kernel.txt",
            "4->3->1_non_obs_kernel.txt",
            "4->3->2_non_obs_kernel.txt",
            "4->3->5_non_obs_kernel.txt",
            "5->3->1_non_obs_kernel.txt",
            "5->3->2_non_obs_kernel.txt",
            "5->3->4_non_obs_kernel.txt",
            "3->4_obs_kernel.txt",
            "2->4_obs_kernel.txt",
            "1->2_L_s.txt",
            "1->3_L_s.txt",
            "2->1_L_s.txt",
            "2->3_L_s.txt",
            "2->4_L_s.txt",
            "2->4_L_t.txt",
            "3->1_L_s.txt",
            "3->2_L_s.txt",
            "3->4_L_s.txt",
            "3->4_L_t.txt",
            "3->5_L_s.txt",
            "5->3_L_s.txt",
        ]

        # Distinct reference values from the matlab implementation. In that
        # output the "a->b->c" non-observed kernels are identical for files
        # sharing the middle node b, the "s->t" L_s matrices are identical for
        # files sharing the source node s, and the obs_kernel / L_t files both
        # target node 4 and are identical to each other.
        non_obs_kernel_by_middle = {
            1: [[1.000000, 0.712741, 0.667145], [0.712741, 1.000000, 0.997059], [0.667145, 0.997059, 1.000000]],
            2: [[1.000000, 0.606711, 0.745976], [0.606711, 1.000000, 0.972967], [0.745976, 0.972967, 1.000000]],
            3: [[1.000000, 0.788336, 0.836137], [0.788336, 1.000000, 0.995830], [0.836137, 0.995830, 1.000000]],
            4: [[1.000000, 0.943759, 0.735416], [0.943759, 1.000000, 0.906238], [0.735416, 0.906238, 1.000000]],
        }
        l_s_by_source = {
            1: [[1.048809, 0.000000, 0.000000], [0.679572, 0.798863, 0.000000], [0.636098, 0.706985, 0.442211]],
            2: [[1.048809, 0.000000, 0.000000], [0.578476, 0.874852, 0.000000], [0.711260, 0.641846, 0.426782]],
            3: [[1.048809, 0.000000, 0.000000], [0.751649, 0.731453, 0.000000], [0.797226, 0.542204, 0.412852]],
            5: [[1.048809, 0.000000, 0.000000], [0.344417, 0.990645, 0.000000], [0.952696, 0.054589, 0.435191]],
        }
        obs_kernel_target_4 = [[0.999346], [0.931603], [0.714382]]
        l_t_target_4 = [[1.048809, 0.000000, 0.000000], [0.899839, 0.538785, 0.000000], [0.701191, 0.510924, 0.589311]]

        reference_values = {
            "1->2->3_non_obs_kernel.txt": non_obs_kernel_by_middle[2],
            "1->2->4_non_obs_kernel.txt": non_obs_kernel_by_middle[2],
            "1->3->2_non_obs_kernel.txt": non_obs_kernel_by_middle[3],
            "1->3->4_non_obs_kernel.txt": non_obs_kernel_by_middle[3],
            "1->3->5_non_obs_kernel.txt": non_obs_kernel_by_middle[3],
            "2->1->3_non_obs_kernel.txt": non_obs_kernel_by_middle[1],
            "2->3->1_non_obs_kernel.txt": non_obs_kernel_by_middle[3],
            "2->3->4_non_obs_kernel.txt": non_obs_kernel_by_middle[3],
            "2->3->5_non_obs_kernel.txt": non_obs_kernel_by_middle[3],
            "2->4->3_non_obs_kernel.txt": non_obs_kernel_by_middle[4],
            "3->1->2_non_obs_kernel.txt": non_obs_kernel_by_middle[1],
            "3->2->1_non_obs_kernel.txt": non_obs_kernel_by_middle[2],
            "3->2->4_non_obs_kernel.txt": non_obs_kernel_by_middle[2],
            "3->4->2_non_obs_kernel.txt": non_obs_kernel_by_middle[4],
            "4->2->1_non_obs_kernel.txt": non_obs_kernel_by_middle[2],
            "4->2->3_non_obs_kernel.txt": non_obs_kernel_by_middle[2],
            "4->3->1_non_obs_kernel.txt": non_obs_kernel_by_middle[3],
            "4->3->2_non_obs_kernel.txt": non_obs_kernel_by_middle[3],
            "4->3->5_non_obs_kernel.txt": non_obs_kernel_by_middle[3],
            "5->3->1_non_obs_kernel.txt": non_obs_kernel_by_middle[3],
            "5->3->2_non_obs_kernel.txt": non_obs_kernel_by_middle[3],
            "5->3->4_non_obs_kernel.txt": non_obs_kernel_by_middle[3],
            "3->4_obs_kernel.txt": obs_kernel_target_4,
            "2->4_obs_kernel.txt": obs_kernel_target_4,
            "1->2_L_s.txt": l_s_by_source[1],
            "1->3_L_s.txt": l_s_by_source[1],
            "2->1_L_s.txt": l_s_by_source[2],
            "2->3_L_s.txt": l_s_by_source[2],
            "2->4_L_s.txt": l_s_by_source[2],
            "2->4_L_t.txt": l_t_target_4,
            "3->1_L_s.txt": l_s_by_source[3],
            "3->2_L_s.txt": l_s_by_source[3],
            "3->4_L_s.txt": l_s_by_source[3],
            "3->4_L_t.txt": l_t_target_4,
            "3->5_L_s.txt": l_s_by_source[3],
            "5->3_L_s.txt": l_s_by_source[5],
        }

        # a fresh array per file, so no two expected matrices share storage
        matrices = {name: asarray(values) for name, values in reference_values.items()}

        # every listed file needs exactly one reference matrix and vice versa;
        # comparing sets (not just lengths) catches a duplicated or mistyped
        # name that would otherwise leave a reference matrix unchecked
        self.assertEqual(len(filenames), len(set(filenames)))
        self.assertEqual(set(filenames), set(matrices))

        for filename in filenames:
            self.assertTrue(self.assert_file_matrix(self.output_folder + filename, matrices[filename]), filename)