コード例 #1
0
ファイル: test_interpolation.py プロジェクト: marses/gpytorch
    def test_interpolation(self):
        """Interpolating t**2 from grid values should reproduce t**2 at the data points."""
        # 100 sample points inside a 50-point grid that overhangs [0, 1] on both sides.
        points = torch.linspace(0.01, 1, 100).unsqueeze(1)
        grid = torch.linspace(-0.05, 1.05, 50).unsqueeze(1)

        interp_indices, interp_values = Interpolation().interpolate(grid, points)
        # Squeeze dim 0 in place (a no-op unless that dimension has size 1).
        interp_indices.squeeze_(0)
        interp_values.squeeze_(0)

        grid_targets = grid.squeeze(1).pow(2).unsqueeze(1)
        expected = points.pow(2).squeeze(-1)

        actual = gpytorch.utils.interpolation.left_interp(
            interp_indices, interp_values, grid_targets,
        ).squeeze()
        self.assertTrue(test._utils.approx_equal(actual, expected))
コード例 #2
0
    def test_interpolation(self):
        """Interpolating t**2 from a row-vector grid should reproduce t**2 at the data points."""
        # Sample points as a column; the grid is a single row overhanging [0, 1].
        points = torch.linspace(0.01, 1, 100).unsqueeze(1)
        grid = torch.linspace(-0.05, 1.05, 50).unsqueeze(0)

        interp_indices, interp_values = Interpolation().interpolate(
            Variable(grid), Variable(points))
        # Squeeze dim 0 in place (a no-op unless that dimension has size 1).
        interp_indices.squeeze_(0)
        interp_values.squeeze_(0)

        grid_targets = grid.squeeze(0).pow(2).unsqueeze(1)
        expected = points.pow(2).squeeze(-1)

        actual = utils.left_interp(
            interp_indices.data, interp_values.data, grid_targets).squeeze()
        self.assertTrue(utils.approx_equal(actual, expected))
コード例 #3
0
def test_interpolation():
    """Interpolating t**2 from grid values should reproduce t**2 at the data points."""
    points = torch.linspace(0.01, 1, 100).unsqueeze(1)
    grid = torch.linspace(-0.05, 1.05, 50).unsqueeze(0)

    interp_indices, interp_values = Interpolation().interpolate(grid, points)
    # In-place squeeze of dim 0 (a no-op unless that dimension has size 1).
    for buf in (interp_indices, interp_values):
        buf.squeeze_(0)

    grid_targets = grid.squeeze(0).pow(2).unsqueeze(1)
    expected = points.pow(2).squeeze(-1)

    actual = utils.left_interp(interp_indices, interp_values, grid_targets).squeeze()
    assert utils.approx_equal(actual, expected)
コード例 #4
0
def test_interpolation():
    """Multiplying grid values of t**2 by the sparse weights should give t**2 at the data points."""
    points = torch.linspace(0.01, 1, 100)
    grid = torch.linspace(-0.05, 1.05, 50)

    J, C = Interpolation().interpolate(grid, points)
    # W is presumably the sparse interpolation matrix assembled from the
    # (index, coefficient) pairs, with len(grid) columns.
    W = utils.toeplitz.index_coef_to_sparse(J, C, len(grid))

    expected = points.pow(2)
    actual = torch.dsmm(W, grid.pow(2).unsqueeze(1)).squeeze()

    # Elementwise relative error; the 1e-10 term guards against division by zero.
    relative_error = torch.abs(actual - expected) / (expected + 1e-10)
    assert all(relative_error < 1e-5)
コード例 #5
0
    def forward(self, x1, x2, **kwargs):
        """Build the grid-interpolated kernel between ``x1`` and ``x2``.

        Both inputs are interpolated onto ``self.grid``. The base kernel is
        evaluated only between the first grid point and the whole grid, and the
        result is wrapped together with the interpolation data in a
        ``ToeplitzLazyVariable``.

        Args:
            x1: 2-D ``(n, d)`` input (tensor/Variable); only ``d == 1`` is supported.
            x2: 2-D ``(m, d)`` input.
            **kwargs: forwarded to ``self.base_kernel_module``.

        Returns:
            ToeplitzLazyVariable: the lazily represented kernel matrix.

        Raises:
            RuntimeError: if the inputs have more than one column, no grid has
                been set, or the data lies outside ``self.grid_bounds`` without
                exactly spanning the grid's first and last points.
        """
        _, d = x1.size()

        if d > 1:
            # Implicit literal concatenation; a backslash continuation inside the
            # literal would embed the next line's indentation in the message.
            raise RuntimeError(
                'The grid interpolation kernel can only be applied to inputs of a single '
                'dimension at this time until Kronecker structure is implemented.'
            )

        if self.grid is None:
            raise RuntimeError(
                'This GridInterpolationKernel has no grid. Call initialize_interpolation_grid '
                'on a GPModel first.'
            )

        both_min = torch.min(x1.min(0)[0].data, x2.min(0)[0].data)[0]
        both_max = torch.max(x1.max(0)[0].data, x2.max(0)[0].data)[0]

        if both_min < self.grid_bounds[0] or both_max > self.grid_bounds[1]:
            # Out-of-bounds data is still ok if we are specifically computing kernel
            # values for grid entries, i.e. the data spans exactly the grid endpoints.
            if torch.abs(both_min - self.grid[0].data)[0] > 1e-7 or torch.abs(
                    both_max - self.grid[-1].data)[0] > 1e-7:
                raise RuntimeError(
                    'Received data that was out of bounds for the specified grid. '
                    'Grid bounds were ({}, {}), but min = {}, max = {}'.format(
                        self.grid_bounds[0], self.grid_bounds[1], both_min, both_max))

        # d == 1 here, so squeeze only the feature column; a bare squeeze() could
        # also drop the row dimension when there is a single data point.
        J1, C1 = Interpolation().interpolate(self.grid.data, x1.data.squeeze(-1))
        J2, C2 = Interpolation().interpolate(self.grid.data, x2.data.squeeze(-1))

        # Only k(grid[0], grid) is evaluated. NOTE(review): ToeplitzLazyVariable
        # presumably rebuilds the full grid kernel from this single row (which
        # assumes a stationary kernel on an evenly spaced grid) — confirm.
        k_UU = self.base_kernel_module(self.grid[0], self.grid, **kwargs).squeeze()

        return ToeplitzLazyVariable(k_UU, J1, C1, J2, C2)
コード例 #6
0
    def test_multidim_interpolation(self):
        """Check multi-dimensional interpolation indices and weights against reference values.

        Four points in three dimensions are interpolated onto an 11-point grid
        per dimension, and both outputs are compared to hard-coded tables.
        """
        # The 3 x 4 literal is transposed, giving 4 points (rows) x 3 dimensions.
        x = torch.Tensor([
            [0.25, 0.45, 0.65, 0.85],
            [0.35, 0.375, 0.4, 0.425],
            [0.45, 0.5, 0.55, 0.6],
        ]).t().contiguous()
        # The same 11 evenly spaced points on [0, 1] for each of the 3 dimensions: shape (3, 11).
        grid = torch.linspace(0., 1., 11).unsqueeze(0).repeat(3, 1)

        indices, values = Interpolation().interpolate(Variable(grid),
                                                      Variable(x))

        # Expected indices: five blocks concatenated along dim 1, giving 4 rows
        # (one per point) and 4 * 13 + 12 = 64 columns.
        # NOTE(review): 64 == 4 ** 3 suggests 4 interpolation points per dimension,
        # with indices into the flattened 11 ** 3 grid — confirm against Interpolation.
        actual_indices = torch.cat([
            torch.LongTensor([
                [
                    146, 147, 148, 149, 157, 158, 159, 160, 168, 169, 170, 171,
                    179
                ],
                [
                    389, 390, 391, 392, 400, 401, 402, 403, 411, 412, 413, 414,
                    422
                ],
                [
                    642, 643, 644, 645, 653, 654, 655, 656, 664, 665, 666, 667,
                    675
                ],
                [
                    885, 886, 887, 888, 896, 897, 898, 899, 907, 908, 909, 910,
                    918
                ],
            ]),
            torch.LongTensor([
                [
                    180, 181, 182, 267, 268, 269, 270, 278, 279, 280, 281, 289,
                    290
                ],
                [
                    423, 424, 425, 510, 511, 512, 513, 521, 522, 523, 524, 532,
                    533
                ],
                [
                    676, 677, 678, 763, 764, 765, 766, 774, 775, 776, 777, 785,
                    786
                ],
                [
                    919, 920, 921, 1006, 1007, 1008, 1009, 1017, 1018, 1019,
                    1020, 1028, 1029
                ],
            ]),
            torch.LongTensor([
                [
                    291, 292, 300, 301, 302, 303, 388, 389, 390, 391, 399, 400,
                    401
                ],
                [
                    534, 535, 543, 544, 545, 546, 631, 632, 633, 634, 642, 643,
                    644
                ],
                [
                    787, 788, 796, 797, 798, 799, 884, 885, 886, 887, 895, 896,
                    897
                ],
                [
                    1030, 1031, 1039, 1040, 1041, 1042, 1127, 1128, 1129, 1130,
                    1138, 1139, 1140
                ],
            ]),
            torch.LongTensor([
                [
                    402, 410, 411, 412, 413, 421, 422, 423, 424, 509, 510, 511,
                    512
                ],
                [
                    645, 653, 654, 655, 656, 664, 665, 666, 667, 752, 753, 754,
                    755
                ],
                [
                    898, 906, 907, 908, 909, 917, 918, 919, 920, 1005, 1006,
                    1007, 1008
                ],
                [
                    1141, 1149, 1150, 1151, 1152, 1160, 1161, 1162, 1163, 1248,
                    1249, 1250, 1251
                ],
            ]),
            torch.LongTensor([
                [520, 521, 522, 523, 531, 532, 533, 534, 542, 543, 544, 545],
                [763, 764, 765, 766, 774, 775, 776, 777, 785, 786, 787, 788],
                [
                    1016, 1017, 1018, 1019, 1027, 1028, 1029, 1030, 1038, 1039,
                    1040, 1041
                ],
                [
                    1259, 1260, 1261, 1262, 1270, 1271, 1272, 1273, 1281, 1282,
                    1283, 1284
                ],
            ])
        ], 1)
        self.assertTrue(utils.approx_equal(indices.data, actual_indices))

        # Expected weights, also 4 rows x 64 columns (6 * 10 + 4), matching
        # actual_indices column for column. Values are rounded to four decimals,
        # so the check relies on utils.approx_equal's tolerance.
        # NOTE(review): confirm that tolerance.
        actual_values = torch.cat([
            torch.Tensor([
                [
                    -0.0002, 0.0022, 0.0022, -0.0002, 0.0022, -0.0198, -0.0198,
                    0.0022, 0.0022, -0.0198
                ],
                [
                    0.0000, 0.0015, 0.0000, 0.0000, -0.0000, -0.0142, -0.0000,
                    -0.0000, -0.0000, -0.0542
                ],
                [
                    0.0000, -0.0000, -0.0000, 0.0000, 0.0039, -0.0352, -0.0352,
                    0.0039, 0.0000, -0.0000
                ],
                [
                    0.0000, 0.0044, 0.0000, 0.0000, -0.0000, -0.0542, -0.0000,
                    -0.0000, -0.0000, -0.0142
                ],
            ]),
            torch.Tensor([
                [
                    -0.0198, 0.0022, -0.0002, 0.0022, 0.0022, -0.0002, 0.0022,
                    -0.0198, -0.0198, 0.0022
                ],
                [
                    -0.0000, -0.0000, 0.0000, 0.0044, 0.0000, 0.0000, -0.0000,
                    -0.0132, -0.0000, -0.0000
                ],
                [
                    -0.0000, 0.0000, 0.0000, -0.0000, -0.0000, 0.0000, -0.0000,
                    0.0000, 0.0000, -0.0000
                ],
                [
                    -0.0000, -0.0000, 0.0000, 0.0015, 0.0000, 0.0000, -0.0000,
                    -0.0396, -0.0000, -0.0000
                ],
            ]),
            torch.Tensor([
                [
                    -0.0198, 0.1780, 0.1780, -0.0198, -0.0198, 0.1780, 0.1780,
                    -0.0198, 0.0022, -0.0198
                ],
                [
                    0.0000, 0.1274, 0.0000, 0.0000, 0.0000, 0.4878, 0.0000,
                    0.0000, -0.0000, -0.0396
                ],
                [
                    -0.0352, 0.3164, 0.3164, -0.0352, -0.0000, 0.0000, 0.0000,
                    -0.0000, -0.0000, 0.0000
                ],
                [
                    0.0000, 0.4878, 0.0000, 0.0000, 0.0000, 0.1274, 0.0000,
                    0.0000, -0.0000, -0.0132
                ],
            ]),
            torch.Tensor([
                [
                    -0.0198, 0.0022, 0.0022, -0.0198, -0.0198, 0.0022, -0.0198,
                    0.1780, 0.1780, -0.0198
                ],
                [
                    -0.0000, -0.0000, -0.0000, -0.0132, -0.0000, -0.0000,
                    0.0000, 0.1274, 0.0000, 0.0000
                ],
                [
                    0.0000, -0.0000, -0.0000, 0.0000, 0.0000, -0.0000, -0.0352,
                    0.3164, 0.3164, -0.0352
                ],
                [
                    -0.0000, -0.0000, -0.0000, -0.0396, -0.0000, -0.0000,
                    0.0000, 0.4878, 0.0000, 0.0000
                ],
            ]),
            torch.Tensor([
                [
                    -0.0198, 0.1780, 0.1780, -0.0198, 0.0022, -0.0198, -0.0198,
                    0.0022, -0.0002, 0.0022
                ],
                [
                    0.0000, 0.4878, 0.0000, 0.0000, -0.0000, -0.0396, -0.0000,
                    -0.0000, 0.0000, 0.0015
                ],
                [
                    -0.0000, 0.0000, 0.0000, -0.0000, -0.0000, 0.0000, 0.0000,
                    -0.0000, 0.0000, -0.0000
                ],
                [
                    0.0000, 0.1274, 0.0000, 0.0000, -0.0000, -0.0132, -0.0000,
                    -0.0000, 0.0000, 0.0044
                ],
            ]),
            torch.Tensor([
                [
                    0.0022, -0.0002, 0.0022, -0.0198, -0.0198, 0.0022, 0.0022,
                    -0.0198, -0.0198, 0.0022
                ],
                [
                    0.0000, 0.0000, -0.0000, -0.0142, -0.0000, -0.0000,
                    -0.0000, -0.0542, -0.0000, -0.0000
                ],
                [
                    -0.0000, 0.0000, 0.0039, -0.0352, -0.0352, 0.0039, 0.0000,
                    -0.0000, -0.0000, 0.0000
                ],
                [
                    0.0000, 0.0000, -0.0000, -0.0542, -0.0000, -0.0000,
                    -0.0000, -0.0142, -0.0000, -0.0000
                ],
            ]),
            torch.Tensor([
                [-0.0002, 0.0022, 0.0022, -0.0002],
                [0.0000, 0.0044, 0.0000, 0.0000],
                [0.0000, -0.0000, -0.0000, 0.0000],
                [0.0000, 0.0015, 0.0000, 0.0000],
            ]),
        ], 1)
        self.assertTrue(utils.approx_equal(values.data, actual_values))
コード例 #7
0
    def _compute_grid(self, x1, x2):
        if not self.has_grid:
            raise RuntimeError(
                'GridInterpolationKernel requires setting the interpolation grid'
            )

        n, d = x1.size()
        m, _ = x2.size()

        if d > 1:
            Js1 = x1.data.new(d, len(x1.data), 4).zero_().long()
            Cs1 = x1.data.new(d, len(x1.data), 4).zero_()
            Js2 = x1.data.new(d, len(x1.data), 4).zero_().long()
            Cs2 = x1.data.new(d, len(x1.data), 4).zero_()
            for i in range(d):
                both_min = torch.min(x1.min(0)[0].data, x2.min(0)[0].data)[i]
                both_max = torch.max(x1.max(0)[0].data, x2.max(0)[0].data)[i]
                if both_min < self.grid_bounds[i][
                        0] or both_max > self.grid_bounds[i][1]:
                    # Out of bounds data is still ok if we are specifically computing kernel values for grid entries.
                    if math.fabs(both_min - self.grid[i, 0]) > 1e-7:
                        raise RuntimeError(
                            'Received data that was out of bounds for the specified grid. \
                                            Grid bounds were ({}, {}), but min = {}, \
                                            max = {}'.format(
                                self.grid_bounds[i][0], self.grid_bounds[i][1],
                                both_min, both_max))
                    elif math.fabs(both_max - self.grid[i, -1]) > 1e-7:
                        raise RuntimeError(
                            'Received data that was out of bounds for the specified grid. \
                                            Grid bounds were ({}, {}), but min = {}, \
                                            max = {}'.format(
                                self.grid_bounds[i][0], self.grid_bounds[i][1],
                                both_min, both_max))
                Js1[i], Cs1[i] = Interpolation().interpolate(
                    self.grid[i], x1.data[:, i])
                Js2[i], Cs2[i] = Interpolation().interpolate(
                    self.grid[i], x2.data[:, i])
            return Js1, Cs1, Js2, Cs2

        both_min = torch.min(x1.min(0)[0].data, x2.min(0)[0].data)[0]
        both_max = torch.max(x1.max(0)[0].data, x2.max(0)[0].data)[0]

        if both_min < self.grid_bounds[0][0] or both_max > self.grid_bounds[0][
                1]:
            # Out of bounds data is still ok if we are specifically computing kernel values for grid entries.
            if math.fabs(both_min - self.grid[0, 0]) > 1e-7:
                raise RuntimeError(
                    'Received data that was out of bounds for the specified grid. \
                                    Grid bounds were ({}, {}), but min = {}, max = {}'
                    .format(self.grid_bounds[0][0], self.grid_bounds[0][1],
                            both_min, both_max))
            elif math.fabs(both_max - self.grid[0, -1]) > 1e-7:
                raise RuntimeError(
                    'Received data that was out of bounds for the specified grid. \
                                    Grid bounds were ({}, {}), but min = {}, max = {}'
                    .format(self.grid_bounds[0][0], self.grid_bounds[0][1],
                            both_min, both_max))
        J1, C1 = Interpolation().interpolate(self.grid[0], x1.data.squeeze())
        J2, C2 = Interpolation().interpolate(self.grid[0], x2.data.squeeze())
        return J1, C1, J2, C2