예제 #1
    def test_interpolation(self):
        """Check that interpolating f(t) = t**2 from grid samples approximately reproduces f at x.

        The grid is 50 points on [-0.05, 1.05]; the query points are 100
        points on [0.01, 1]. Both are column tensors (``unsqueeze(1)``).
        """
        grid = torch.linspace(-0.05, 1.05, 50).unsqueeze(1)
        x = torch.linspace(0.01, 1, 100).unsqueeze(1)

        interp_indices, interp_values = Interpolation().interpolate(grid, x)
        # Squeeze dim 0 in place (a no-op when that dim is not of size 1).
        interp_indices = interp_indices.squeeze_(0)
        interp_values = interp_values.squeeze_(0)

        # Target function sampled on the grid (as a column) and at the query points.
        f_on_grid = grid.squeeze(1).pow(2).unsqueeze(1)
        f_at_x = x.pow(2).squeeze(-1)

        f_interp = gpytorch.utils.interpolation.left_interp(
            interp_indices, interp_values, f_on_grid
        )
        self.assertTrue(test._utils.approx_equal(f_interp.squeeze(), f_at_x))
예제 #2
    def test_interpolation(self):
        """Check that interpolating f(t) = t**2 from grid samples approximately reproduces f at x.

        The grid is a row tensor (``unsqueeze(0)``) of 50 points on
        [-0.05, 1.05]; the query points are a column of 100 points on
        [0.01, 1]. Both are wrapped in ``Variable`` before interpolation.
        """
        grid = torch.linspace(-0.05, 1.05, 50).unsqueeze(0)
        x = torch.linspace(0.01, 1, 100).unsqueeze(1)

        interp_indices, interp_values = Interpolation().interpolate(Variable(grid), Variable(x))
        # Squeeze dim 0 in place (a no-op when that dim is not of size 1).
        interp_indices = interp_indices.squeeze_(0)
        interp_values = interp_values.squeeze_(0)

        # Target function sampled on the grid (as a column) and at the query points.
        f_on_grid = grid.squeeze(0).pow(2).unsqueeze(1)
        f_at_x = x.pow(2).squeeze(-1)

        # left_interp is called on the raw tensors underneath the Variables.
        f_interp = utils.left_interp(interp_indices.data, interp_values.data, f_on_grid)
        self.assertTrue(utils.approx_equal(f_interp.squeeze(), f_at_x))
예제 #3
def test_interpolation():
    """Check that interpolating f(t) = t**2 from grid samples approximately reproduces f at x.

    The grid is a row tensor of 50 points on [-0.05, 1.05]; the query points
    are a column of 100 points on [0.01, 1].
    """
    x = torch.linspace(0.01, 1, 100).unsqueeze(1)
    grid = torch.linspace(-0.05, 1.05, 50).unsqueeze(0)
    indices, values = Interpolation().interpolate(grid, x)
    test_func_grid = grid.squeeze(0).pow(2)
    test_func_x = x.pow(2).squeeze(-1)

    # The right-hand side is the grid samples as a column; the result is
    # squeezed so it can be compared with the 1-D ``test_func_x``.
    # NOTE(review): this call was truncated in the source; completed to match
    # the sibling examples that use ``left_interp`` the same way.
    interp_func_x = utils.left_interp(indices, values,
                                      test_func_grid.unsqueeze(1)).squeeze()

    assert utils.approx_equal(interp_func_x, test_func_x)
예제 #4
def test_interpolation():
    """Check the sparse interpolation matrix built from (J, C) reproduces f(t) = t**2.

    The test interpolates from 50 grid points on [-0.05, 1.05] to 100 points
    on [0.01, 1] and requires every relative error to be below 1e-5.
    """
    grid = torch.linspace(-0.05, 1.05, 50)
    x = torch.linspace(0.01, 1, 100)

    interp_idx, interp_coef = Interpolation().interpolate(grid, x)
    interp_matrix = utils.toeplitz.index_coef_to_sparse(
        interp_idx, interp_coef, len(grid))

    f_on_grid = grid ** 2
    f_at_x = x ** 2

    f_interp = torch.dsmm(interp_matrix, f_on_grid.unsqueeze(1)).squeeze()

    # The 1e-10 keeps the relative error finite if a target value were 0.
    relative_error = torch.abs(f_interp - f_at_x) / (f_at_x + 1e-10)
    assert all(relative_error < 1e-5)
    def forward(self, x1, x2, **kwargs):
        """Compute the grid-interpolated kernel between ``x1`` and ``x2``.

        Both inputs are mapped onto ``self.grid`` with ``Interpolation``. The
        base kernel is evaluated only between the first grid point and the
        whole grid. The result is wrapped in a ``ToeplitzLazyVariable`` together
        with both inputs' interpolation indices and coefficients.

        Args:
            x1: 2-D inputs of shape ``n x d``; only ``d == 1`` is supported.
            x2: 2-D inputs of shape ``m x d``.
            **kwargs: forwarded to ``self.base_kernel_module``.

        Returns:
            ToeplitzLazyVariable: lazy representation of the kernel matrix.

        Raises:
            RuntimeError: if ``d > 1``, if ``self.grid`` is ``None``, or if the
                data lies outside ``self.grid_bounds`` without its min/max
                coinciding (within 1e-7) with the grid's first/last points.
        """
        n, d = x1.size()
        m, _ = x2.size()

        if d > 1:
            raise RuntimeError(
                'The grid interpolation kernel can only be applied to inputs of a single dimension at this time \
                until Kronecker structure is implemented.'
            )

        if self.grid is None:
            raise RuntimeError(
                'This GridInterpolationKernel has no grid. Call initialize_interpolation_grid \
                 on a GPModel first.'
            )

        both_min = torch.min(x1.min(0)[0].data, x2.min(0)[0].data)[0]
        both_max = torch.max(x1.max(0)[0].data, x2.max(0)[0].data)[0]

        if both_min < self.grid_bounds[0] or both_max > self.grid_bounds[1]:
            # Out of bounds data is still ok if we are specifically computing kernel values for grid entries.
            if torch.abs(both_min - self.grid[0].data)[0] > 1e-7 or torch.abs(
                    both_max - self.grid[-1].data)[0] > 1e-7:
                raise RuntimeError(
                    'Received data that was out of bounds for the specified grid. \
                                    Grid bounds were ({}, {}), but min = {}, max = {}'
                    .format(self.grid_bounds[0], self.grid_bounds[1], both_min,
                            both_max))

        J1, C1 = Interpolation().interpolate(self.grid.data, x1.data.squeeze())
        J2, C2 = Interpolation().interpolate(self.grid.data, x2.data.squeeze())

        # Only the kernel between grid[0] and the full grid is computed;
        # presumably this relies on a stationary base kernel on a regular grid,
        # whose grid kernel matrix is then Toeplitz -- TODO confirm.
        # NOTE(review): this call was truncated in the source; reconstructed
        # by forwarding **kwargs and squeezing to a vector.
        k_UU = self.base_kernel_module(self.grid[0], self.grid,
                                       **kwargs).squeeze()

        K_XX = ToeplitzLazyVariable(k_UU, J1, C1, J2, C2)

        return K_XX
예제 #6
    def test_multidim_interpolation(self):
        """Compare interpolation indices/values against hard-coded expected tensors.

        NOTE(review): this snippet is damaged and does not parse as shown.
        The closing bracket of ``x``, the second argument to ``interpolate``,
        and the row openers / closing brackets inside ``actual_indices`` and
        ``actual_values`` appear to be missing. The expected data must be
        recovered from the upstream source before this test can run.
        """
        # Input points, one inner list per row.
        # NOTE(review): the closing "])" of this literal is missing here.
        x = torch.Tensor([
            [0.25, 0.45, 0.65, 0.85],
            [0.35, 0.375, 0.4, 0.425],
            [0.45, 0.5, 0.55, 0.6],
        # 11 evenly spaced points on [0, 1], repeated into 3 rows (shape 3 x 11).
        grid = torch.linspace(0., 1., 11).unsqueeze(0).repeat(3, 1)

        # NOTE(review): this call is cut off after its first argument.
        indices, values = Interpolation().interpolate(Variable(grid),

        # Expected interpolation indices, concatenated along dim 1.
        # NOTE(review): row openers and brackets of this literal were lost.
        actual_indices = torch.cat([
                    146, 147, 148, 149, 157, 158, 159, 160, 168, 169, 170, 171,
                    389, 390, 391, 392, 400, 401, 402, 403, 411, 412, 413, 414,
                    642, 643, 644, 645, 653, 654, 655, 656, 664, 665, 666, 667,
                    885, 886, 887, 888, 896, 897, 898, 899, 907, 908, 909, 910,
                    180, 181, 182, 267, 268, 269, 270, 278, 279, 280, 281, 289,
                    423, 424, 425, 510, 511, 512, 513, 521, 522, 523, 524, 532,
                    676, 677, 678, 763, 764, 765, 766, 774, 775, 776, 777, 785,
                    919, 920, 921, 1006, 1007, 1008, 1009, 1017, 1018, 1019,
                    1020, 1028, 1029
                    291, 292, 300, 301, 302, 303, 388, 389, 390, 391, 399, 400,
                    534, 535, 543, 544, 545, 546, 631, 632, 633, 634, 642, 643,
                    787, 788, 796, 797, 798, 799, 884, 885, 886, 887, 895, 896,
                    1030, 1031, 1039, 1040, 1041, 1042, 1127, 1128, 1129, 1130,
                    1138, 1139, 1140
                    402, 410, 411, 412, 413, 421, 422, 423, 424, 509, 510, 511,
                    645, 653, 654, 655, 656, 664, 665, 666, 667, 752, 753, 754,
                    898, 906, 907, 908, 909, 917, 918, 919, 920, 1005, 1006,
                    1007, 1008
                    1141, 1149, 1150, 1151, 1152, 1160, 1161, 1162, 1163, 1248,
                    1249, 1250, 1251
                [520, 521, 522, 523, 531, 532, 533, 534, 542, 543, 544, 545],
                [763, 764, 765, 766, 774, 775, 776, 777, 785, 786, 787, 788],
                    1016, 1017, 1018, 1019, 1027, 1028, 1029, 1030, 1038, 1039,
                    1040, 1041
                    1259, 1260, 1261, 1262, 1270, 1271, 1272, 1273, 1281, 1282,
                    1283, 1284
        ], 1)
        self.assertTrue(utils.approx_equal(indices.data, actual_indices))

        # Expected interpolation values, concatenated along dim 1.
        # NOTE(review): row openers and brackets of this literal were lost.
        actual_values = torch.cat([
                    -0.0002, 0.0022, 0.0022, -0.0002, 0.0022, -0.0198, -0.0198,
                    0.0022, 0.0022, -0.0198
                    0.0000, 0.0015, 0.0000, 0.0000, -0.0000, -0.0142, -0.0000,
                    -0.0000, -0.0000, -0.0542
                    0.0000, -0.0000, -0.0000, 0.0000, 0.0039, -0.0352, -0.0352,
                    0.0039, 0.0000, -0.0000
                    0.0000, 0.0044, 0.0000, 0.0000, -0.0000, -0.0542, -0.0000,
                    -0.0000, -0.0000, -0.0142
                    -0.0198, 0.0022, -0.0002, 0.0022, 0.0022, -0.0002, 0.0022,
                    -0.0198, -0.0198, 0.0022
                    -0.0000, -0.0000, 0.0000, 0.0044, 0.0000, 0.0000, -0.0000,
                    -0.0132, -0.0000, -0.0000
                    -0.0000, 0.0000, 0.0000, -0.0000, -0.0000, 0.0000, -0.0000,
                    0.0000, 0.0000, -0.0000
                    -0.0000, -0.0000, 0.0000, 0.0015, 0.0000, 0.0000, -0.0000,
                    -0.0396, -0.0000, -0.0000
                    -0.0198, 0.1780, 0.1780, -0.0198, -0.0198, 0.1780, 0.1780,
                    -0.0198, 0.0022, -0.0198
                    0.0000, 0.1274, 0.0000, 0.0000, 0.0000, 0.4878, 0.0000,
                    0.0000, -0.0000, -0.0396
                    -0.0352, 0.3164, 0.3164, -0.0352, -0.0000, 0.0000, 0.0000,
                    -0.0000, -0.0000, 0.0000
                    0.0000, 0.4878, 0.0000, 0.0000, 0.0000, 0.1274, 0.0000,
                    0.0000, -0.0000, -0.0132
                    -0.0198, 0.0022, 0.0022, -0.0198, -0.0198, 0.0022, -0.0198,
                    0.1780, 0.1780, -0.0198
                    -0.0000, -0.0000, -0.0000, -0.0132, -0.0000, -0.0000,
                    0.0000, 0.1274, 0.0000, 0.0000
                    0.0000, -0.0000, -0.0000, 0.0000, 0.0000, -0.0000, -0.0352,
                    0.3164, 0.3164, -0.0352
                    -0.0000, -0.0000, -0.0000, -0.0396, -0.0000, -0.0000,
                    0.0000, 0.4878, 0.0000, 0.0000
                    -0.0198, 0.1780, 0.1780, -0.0198, 0.0022, -0.0198, -0.0198,
                    0.0022, -0.0002, 0.0022
                    0.0000, 0.4878, 0.0000, 0.0000, -0.0000, -0.0396, -0.0000,
                    -0.0000, 0.0000, 0.0015
                    -0.0000, 0.0000, 0.0000, -0.0000, -0.0000, 0.0000, 0.0000,
                    -0.0000, 0.0000, -0.0000
                    0.0000, 0.1274, 0.0000, 0.0000, -0.0000, -0.0132, -0.0000,
                    -0.0000, 0.0000, 0.0044
                    0.0022, -0.0002, 0.0022, -0.0198, -0.0198, 0.0022, 0.0022,
                    -0.0198, -0.0198, 0.0022
                    0.0000, 0.0000, -0.0000, -0.0142, -0.0000, -0.0000,
                    -0.0000, -0.0542, -0.0000, -0.0000
                    -0.0000, 0.0000, 0.0039, -0.0352, -0.0352, 0.0039, 0.0000,
                    -0.0000, -0.0000, 0.0000
                    0.0000, 0.0000, -0.0000, -0.0542, -0.0000, -0.0000,
                    -0.0000, -0.0142, -0.0000, -0.0000
                [-0.0002, 0.0022, 0.0022, -0.0002],
                [0.0000, 0.0044, 0.0000, 0.0000],
                [0.0000, -0.0000, -0.0000, 0.0000],
                [0.0000, 0.0015, 0.0000, 0.0000],
        ], 1)
        self.assertTrue(utils.approx_equal(values.data, actual_values))
    def _compute_grid(self, x1, x2):
        """Compute interpolation indices/coefficients of ``x1`` and ``x2`` onto the grid.

        Args:
            x1: 2-D inputs of shape ``n x d``.
            x2: 2-D inputs of shape ``m x d``.

        Returns:
            ``(J1, C1, J2, C2)``. When ``d > 1`` the results are stacked per
            dimension: ``J1``/``C1`` have shape ``d x n x 4`` and ``J2``/``C2``
            have shape ``d x m x 4``, with the ``J`` tensors of long type.
            When ``d == 1`` they are exactly what ``Interpolation().interpolate``
            returns for the squeezed inputs.

        Raises:
            RuntimeError: if the interpolation grid has not been set, or if in
                some dimension the data falls outside ``self.grid_bounds``
                without its min/max matching (within 1e-7) that dimension's
                first/last grid points.
        """
        if not self.has_grid:
            raise RuntimeError(
                'GridInterpolationKernel requires setting the interpolation grid'
            )

        n, d = x1.size()
        m, _ = x2.size()

        # Per-dimension extremes over both inputs, computed once up front.
        both_mins = torch.min(x1.min(0)[0].data, x2.min(0)[0].data)
        both_maxs = torch.max(x1.max(0)[0].data, x2.max(0)[0].data)

        if d > 1:
            Js1 = x1.data.new(d, n, 4).zero_().long()
            Cs1 = x1.data.new(d, n, 4).zero_()
            # x2 can have a different number of rows than x1, so size by m.
            Js2 = x1.data.new(d, m, 4).zero_().long()
            Cs2 = x1.data.new(d, m, 4).zero_()
            for i in range(d):
                both_min = both_mins[i]
                both_max = both_maxs[i]
                if both_min < self.grid_bounds[i][0] or both_max > self.grid_bounds[i][1]:
                    # Out of bounds data is still ok if we are specifically computing kernel values for grid entries.
                    if (math.fabs(both_min - self.grid[i, 0]) > 1e-7
                            or math.fabs(both_max - self.grid[i, -1]) > 1e-7):
                        raise RuntimeError(
                            'Received data that was out of bounds for the specified grid. \
                                            Grid bounds were ({}, {}), but min = {}, \
                                            max = {}'.format(
                                self.grid_bounds[i][0], self.grid_bounds[i][1],
                                both_min, both_max))
                Js1[i], Cs1[i] = Interpolation().interpolate(
                    self.grid[i], x1.data[:, i])
                Js2[i], Cs2[i] = Interpolation().interpolate(
                    self.grid[i], x2.data[:, i])
            return Js1, Cs1, Js2, Cs2

        both_min = both_mins[0]
        both_max = both_maxs[0]

        if both_min < self.grid_bounds[0][0] or both_max > self.grid_bounds[0][1]:
            # Out of bounds data is still ok if we are specifically computing kernel values for grid entries.
            if (math.fabs(both_min - self.grid[0, 0]) > 1e-7
                    or math.fabs(both_max - self.grid[0, -1]) > 1e-7):
                raise RuntimeError(
                    'Received data that was out of bounds for the specified grid. \
                                    Grid bounds were ({}, {}), but min = {}, max = {}'
                    .format(self.grid_bounds[0][0], self.grid_bounds[0][1],
                            both_min, both_max))
        J1, C1 = Interpolation().interpolate(self.grid[0], x1.data.squeeze())
        J2, C2 = Interpolation().interpolate(self.grid[0], x2.data.squeeze())
        return J1, C1, J2, C2