Exemplo n.º 1
0
    def test_kernel(self, seed=1, n=500):
        """Check that SpectralGPKernel reconstructs an RBF kernel from its spectrum.

        An RBF kernel (lengthscale 3) is sampled as K(x, 0) at 201 evenly spaced
        points in [-10, 10]. Its spectral density is approximated by a discrete
        cosine transform on ``n`` frequencies in ``[0, 1 / nf]``. The kernel is
        then rebuilt with ``SpectralGPKernel.compute_kernel_values``.

        The test asserts two things:
          * the frequency grid built here equals ``kernel_rec.omega``;
          * the reconstruction's relative L2 error is below 1e-3.

        Args:
            seed: value passed to ``torch.random.manual_seed``.
            n: number of frequency locations (``num_locs`` of the kernel).
        """
        print("test kernel")
        torch.random.manual_seed(seed)
        # input data
        x = torch.linspace(-10, 10, 201).view(-1, 1)

        # nf is twice the mean sample spacing dt, so 1 / nf = 1 / (2 * dt) is
        # the Nyquist frequency of the grid and the top of the frequency range.
        nf = 2.0 * torch.mean(x[1:] - x[:-1])
        print('nyquist frequency: ', 1. / nf)

        # generating kernel values K(x, 0), flattened to 1-D
        kernel = RBFKernel()
        kernel._set_lengthscale(3.)
        k = kernel(x, torch.zeros(1, 1)).evaluate().detach().view(-1)

        # extracting approximate spectral density:
        #   s[j] = sum_i k[i] * cos(2 * pi * tau[i] * omega[j])
        # All frequencies are evaluated at once as one matrix-vector product,
        # instead of one Python-level dot product per frequency.
        omega = torch.linspace(0, 1. / nf, n)
        tau = x.detach()
        phase = 2 * math.pi * tau.view(1, -1) * omega.view(-1, 1)  # (n, len(x))
        s = torch.cos(phase) @ k

        # reconstructing kernel using mean
        kernel_rec = SpectralGPKernel(integration='U',
                                      num_locs=n,
                                      omega_max=1. / nf)

        # check that the generated omegas are computed properly
        self.assertEqual((omega - kernel_rec.omega.squeeze()).norm().item(),
                         0.)

        kernel_output = kernel_rec.compute_kernel_values(
            tau, s, integration='U').view(-1)

        relative_error = ((kernel_output - k).norm() / k.norm()).item()
        print('relative error: ', relative_error)
        self.assertLess(relative_error, 1e-3)
Exemplo n.º 2
0
    def test_kernel_forwards(self, seed=1, n=500):
        print("test kernel")
        torch.random.manual_seed(seed)
        # input data
        x = torch.linspace(-100, 100, 201).view(-1, 1)

        # generating kernel values
        kernel = RBFKernel()
        kernel._set_lengthscale(3.)
        k = kernel(x, torch.zeros(1, 1)).evaluate()
        k = k.detach()

        # reconstructing kernel using mean
        kernel_rec = SpectralGPKernel(integration='U', num_locs=n)