def test_interpolation(self):
    """Interpolating t**2 from its grid values must reproduce t**2 at the query points.

    The 100 query points lie in [0.01, 1], strictly inside the 50-point grid
    spanning [-0.05, 1.05]; both are laid out as column tensors.
    """
    grid = torch.linspace(-0.05, 1.05, 50).unsqueeze(1)
    query = torch.linspace(0.01, 1, 100).unsqueeze(1)

    interp_idx, interp_coef = Interpolation().interpolate(grid, query)
    interp_idx = interp_idx.squeeze_(0)
    interp_coef = interp_coef.squeeze_(0)

    # Ground truth: the same quadratic evaluated on the grid and at the queries.
    grid_vals = grid.squeeze(1).pow(2)
    expected = query.pow(2).squeeze(-1)

    approx = gpytorch.utils.interpolation.left_interp(
        interp_idx, interp_coef, grid_vals.unsqueeze(1)
    ).squeeze()
    self.assertTrue(test._utils.approx_equal(approx, expected))
def test_interpolation(self):
    """Interpolating t**2 from its grid values must reproduce t**2 at the query points.

    Inputs are wrapped in ``Variable``; the interpolation results are unwrapped
    with ``.data`` before being handed to ``utils.left_interp``.
    """
    query = torch.linspace(0.01, 1, 100).unsqueeze(1)
    grid = torch.linspace(-0.05, 1.05, 50).unsqueeze(0)

    idx, coef = Interpolation().interpolate(Variable(grid), Variable(query))
    idx = idx.squeeze_(0)
    coef = coef.squeeze_(0)

    grid_vals = grid.squeeze(0).pow(2)
    expected = query.pow(2).squeeze(-1)

    column = grid_vals.unsqueeze(1)
    approx = utils.left_interp(idx.data, coef.data, column).squeeze()
    self.assertTrue(utils.approx_equal(approx, expected))
def test_interpolation():
    """Interpolating t**2 from its grid values must reproduce t**2 at the query points."""
    grid = torch.linspace(-0.05, 1.05, 50).unsqueeze(0)
    pts = torch.linspace(0.01, 1, 100).unsqueeze(1)

    idx, wts = Interpolation().interpolate(grid, pts)
    # Drop the leading singleton dimension in place (squeeze_ mutates and returns self).
    idx = idx.squeeze_(0)
    wts = wts.squeeze_(0)

    on_grid = grid.squeeze(0).pow(2).unsqueeze(1)
    truth = pts.pow(2).squeeze(-1)
    approx = utils.left_interp(idx, wts, on_grid).squeeze()

    assert utils.approx_equal(approx, truth)
def test_interpolation():
    """Sparse interpolation of t**2 from grid values must match t**2 to 1e-5 relative error.

    The interpolation indices/coefficients are turned into a sparse matrix
    ``W`` with one column per grid point, and ``W @ f(grid)`` is compared to
    ``f(x)`` element-wise.
    """
    grid = torch.linspace(-0.05, 1.05, 50)
    pts = torch.linspace(0.01, 1, 100)

    idx, coef = Interpolation().interpolate(grid, pts)
    weights = utils.toeplitz.index_coef_to_sparse(idx, coef, len(grid))

    target = pts.pow(2)
    approx = torch.dsmm(weights, grid.pow(2).unsqueeze(1)).squeeze()

    # The 1e-10 in the denominator guards against division by zero.
    rel_err = (approx - target).abs() / (target + 1e-10)
    assert all(rel_err < 1e-5)
def forward(self, x1, x2, **kwargs):
    """Build the grid-interpolated kernel between ``x1`` and ``x2``.

    Both inputs are mapped onto ``self.grid`` with ``Interpolation().interpolate``.
    The kernel between the first grid point and every grid point
    (``k_UU``) is wrapped together with those indices/coefficients in a
    ``ToeplitzLazyVariable``.

    Args:
        x1: 2-D input whose second dimension must be 1.
        x2: 2-D input.
        **kwargs: forwarded unchanged to ``self.base_kernel_module``.

    Returns:
        ToeplitzLazyVariable built from ``k_UU`` and the interpolation data.

    Raises:
        RuntimeError: if the inputs have more than one dimension, if no grid
            has been set, or if the data fall outside ``self.grid_bounds``
            without their extremes exactly matching the grid end points.
    """
    # Unpacking both sizes also enforces that both inputs are 2-D.
    _, d = x1.size()
    _, _ = x2.size()
    if d > 1:
        raise RuntimeError(
            'The grid interpolation kernel can only be applied to inputs of a '
            'single dimension at this time until Kronecker structure is implemented.'
        )
    if self.grid is None:
        raise RuntimeError(
            'This GridInterpolationKernel has no grid. Call '
            'initialize_interpolation_grid on a GPModel first.'
        )

    both_min = torch.min(x1.min(0)[0].data, x2.min(0)[0].data)[0]
    both_max = torch.max(x1.max(0)[0].data, x2.max(0)[0].data)[0]
    if both_min < self.grid_bounds[0] or both_max > self.grid_bounds[1]:
        # Out of bounds data is still ok if we are specifically computing kernel
        # values for grid entries, i.e. the extremes coincide with the grid ends.
        if (torch.abs(both_min - self.grid[0].data)[0] > 1e-7
                or torch.abs(both_max - self.grid[-1].data)[0] > 1e-7):
            raise RuntimeError(
                'Received data that was out of bounds for the specified grid. '
                'Grid bounds were ({}, {}), but min = {}, max = {}'.format(
                    self.grid_bounds[0], self.grid_bounds[1], both_min, both_max
                )
            )

    J1, C1 = Interpolation().interpolate(self.grid.data, x1.data.squeeze())
    J2, C2 = Interpolation().interpolate(self.grid.data, x2.data.squeeze())
    k_UU = self.base_kernel_module(self.grid[0], self.grid, **kwargs).squeeze()
    return ToeplitzLazyVariable(k_UU, J1, C1, J2, C2)
def test_multidim_interpolation(self):
    """Regression test: multi-dimensional ``Interpolation().interpolate`` vs. hard-coded output.

    Four 3-D query points (the rows of ``x`` after the transpose) are
    interpolated onto a grid that holds 11 evenly spaced values in [0, 1] for
    each of the 3 dimensions. The expected indices and weights are both
    4 x 64 tensors, assembled below by concatenating column chunks along dim 1.

    NOTE(review): 64 == 4 ** 3 and every index is < 11 ** 3, which suggests
    4 interpolation points per dimension indexing into the flattened 11x11x11
    grid — TODO confirm against Interpolation.
    """
    # 3 x 4 literal transposed to 4 points x 3 dims.
    x = torch.Tensor([
        [0.25, 0.45, 0.65, 0.85],
        [0.35, 0.375, 0.4, 0.425],
        [0.45, 0.5, 0.55, 0.6],
    ]).t().contiguous()
    # One row of 11 grid values per dimension (3 x 11).
    grid = torch.linspace(0., 1., 11).unsqueeze(0).repeat(3, 1)
    indices, values = Interpolation().interpolate(Variable(grid), Variable(x))

    # Expected indices: 4 rows (one per query point), 13+13+13+13+12 = 64 columns.
    actual_indices = torch.cat([
        torch.LongTensor([
            [146, 147, 148, 149, 157, 158, 159, 160, 168, 169, 170, 171, 179],
            [389, 390, 391, 392, 400, 401, 402, 403, 411, 412, 413, 414, 422],
            [642, 643, 644, 645, 653, 654, 655, 656, 664, 665, 666, 667, 675],
            [885, 886, 887, 888, 896, 897, 898, 899, 907, 908, 909, 910, 918],
        ]),
        torch.LongTensor([
            [180, 181, 182, 267, 268, 269, 270, 278, 279, 280, 281, 289, 290],
            [423, 424, 425, 510, 511, 512, 513, 521, 522, 523, 524, 532, 533],
            [676, 677, 678, 763, 764, 765, 766, 774, 775, 776, 777, 785, 786],
            [919, 920, 921, 1006, 1007, 1008, 1009, 1017, 1018, 1019, 1020, 1028, 1029],
        ]),
        torch.LongTensor([
            [291, 292, 300, 301, 302, 303, 388, 389, 390, 391, 399, 400, 401],
            [534, 535, 543, 544, 545, 546, 631, 632, 633, 634, 642, 643, 644],
            [787, 788, 796, 797, 798, 799, 884, 885, 886, 887, 895, 896, 897],
            [1030, 1031, 1039, 1040, 1041, 1042, 1127, 1128, 1129, 1130, 1138, 1139, 1140],
        ]),
        torch.LongTensor([
            [402, 410, 411, 412, 413, 421, 422, 423, 424, 509, 510, 511, 512],
            [645, 653, 654, 655, 656, 664, 665, 666, 667, 752, 753, 754, 755],
            [898, 906, 907, 908, 909, 917, 918, 919, 920, 1005, 1006, 1007, 1008],
            [1141, 1149, 1150, 1151, 1152, 1160, 1161, 1162, 1163, 1248, 1249, 1250, 1251],
        ]),
        torch.LongTensor([
            [520, 521, 522, 523, 531, 532, 533, 534, 542, 543, 544, 545],
            [763, 764, 765, 766, 774, 775, 776, 777, 785, 786, 787, 788],
            [1016, 1017, 1018, 1019, 1027, 1028, 1029, 1030, 1038, 1039, 1040, 1041],
            [1259, 1260, 1261, 1262, 1270, 1271, 1272, 1273, 1281, 1282, 1283, 1284],
        ])
    ], 1)
    self.assertTrue(utils.approx_equal(indices.data, actual_indices))

    # Expected weights: 4 rows, 6 * 10 + 4 = 64 columns, aligned with actual_indices.
    # They are given to 4 decimal places, so this relies on approx_equal's
    # tolerance being loose enough — confirm.
    actual_values = torch.cat([
        torch.Tensor([
            [-0.0002, 0.0022, 0.0022, -0.0002, 0.0022, -0.0198, -0.0198, 0.0022, 0.0022, -0.0198],
            [0.0000, 0.0015, 0.0000, 0.0000, -0.0000, -0.0142, -0.0000, -0.0000, -0.0000, -0.0542],
            [0.0000, -0.0000, -0.0000, 0.0000, 0.0039, -0.0352, -0.0352, 0.0039, 0.0000, -0.0000],
            [0.0000, 0.0044, 0.0000, 0.0000, -0.0000, -0.0542, -0.0000, -0.0000, -0.0000, -0.0142],
        ]),
        torch.Tensor([
            [-0.0198, 0.0022, -0.0002, 0.0022, 0.0022, -0.0002, 0.0022, -0.0198, -0.0198, 0.0022],
            [-0.0000, -0.0000, 0.0000, 0.0044, 0.0000, 0.0000, -0.0000, -0.0132, -0.0000, -0.0000],
            [-0.0000, 0.0000, 0.0000, -0.0000, -0.0000, 0.0000, -0.0000, 0.0000, 0.0000, -0.0000],
            [-0.0000, -0.0000, 0.0000, 0.0015, 0.0000, 0.0000, -0.0000, -0.0396, -0.0000, -0.0000],
        ]),
        torch.Tensor([
            [-0.0198, 0.1780, 0.1780, -0.0198, -0.0198, 0.1780, 0.1780, -0.0198, 0.0022, -0.0198],
            [0.0000, 0.1274, 0.0000, 0.0000, 0.0000, 0.4878, 0.0000, 0.0000, -0.0000, -0.0396],
            [-0.0352, 0.3164, 0.3164, -0.0352, -0.0000, 0.0000, 0.0000, -0.0000, -0.0000, 0.0000],
            [0.0000, 0.4878, 0.0000, 0.0000, 0.0000, 0.1274, 0.0000, 0.0000, -0.0000, -0.0132],
        ]),
        torch.Tensor([
            [-0.0198, 0.0022, 0.0022, -0.0198, -0.0198, 0.0022, -0.0198, 0.1780, 0.1780, -0.0198],
            [-0.0000, -0.0000, -0.0000, -0.0132, -0.0000, -0.0000, 0.0000, 0.1274, 0.0000, 0.0000],
            [0.0000, -0.0000, -0.0000, 0.0000, 0.0000, -0.0000, -0.0352, 0.3164, 0.3164, -0.0352],
            [-0.0000, -0.0000, -0.0000, -0.0396, -0.0000, -0.0000, 0.0000, 0.4878, 0.0000, 0.0000],
        ]),
        torch.Tensor([
            [-0.0198, 0.1780, 0.1780, -0.0198, 0.0022, -0.0198, -0.0198, 0.0022, -0.0002, 0.0022],
            [0.0000, 0.4878, 0.0000, 0.0000, -0.0000, -0.0396, -0.0000, -0.0000, 0.0000, 0.0015],
            [-0.0000, 0.0000, 0.0000, -0.0000, -0.0000, 0.0000, 0.0000, -0.0000, 0.0000, -0.0000],
            [0.0000, 0.1274, 0.0000, 0.0000, -0.0000, -0.0132, -0.0000, -0.0000, 0.0000, 0.0044],
        ]),
        torch.Tensor([
            [0.0022, -0.0002, 0.0022, -0.0198, -0.0198, 0.0022, 0.0022, -0.0198, -0.0198, 0.0022],
            [0.0000, 0.0000, -0.0000, -0.0142, -0.0000, -0.0000, -0.0000, -0.0542, -0.0000, -0.0000],
            [-0.0000, 0.0000, 0.0039, -0.0352, -0.0352, 0.0039, 0.0000, -0.0000, -0.0000, 0.0000],
            [0.0000, 0.0000, -0.0000, -0.0542, -0.0000, -0.0000, -0.0000, -0.0142, -0.0000, -0.0000],
        ]),
        torch.Tensor([
            [-0.0002, 0.0022, 0.0022, -0.0002],
            [0.0000, 0.0044, 0.0000, 0.0000],
            [0.0000, -0.0000, -0.0000, 0.0000],
            [0.0000, 0.0015, 0.0000, 0.0000],
        ]),
    ], 1)
    self.assertTrue(utils.approx_equal(values.data, actual_values))
def _compute_grid(self, x1, x2): if not self.has_grid: raise RuntimeError( 'GridInterpolationKernel requires setting the interpolation grid' ) n, d = x1.size() m, _ = x2.size() if d > 1: Js1 = x1.data.new(d, len(x1.data), 4).zero_().long() Cs1 = x1.data.new(d, len(x1.data), 4).zero_() Js2 = x1.data.new(d, len(x1.data), 4).zero_().long() Cs2 = x1.data.new(d, len(x1.data), 4).zero_() for i in range(d): both_min = torch.min(x1.min(0)[0].data, x2.min(0)[0].data)[i] both_max = torch.max(x1.max(0)[0].data, x2.max(0)[0].data)[i] if both_min < self.grid_bounds[i][ 0] or both_max > self.grid_bounds[i][1]: # Out of bounds data is still ok if we are specifically computing kernel values for grid entries. if math.fabs(both_min - self.grid[i, 0]) > 1e-7: raise RuntimeError( 'Received data that was out of bounds for the specified grid. \ Grid bounds were ({}, {}), but min = {}, \ max = {}'.format( self.grid_bounds[i][0], self.grid_bounds[i][1], both_min, both_max)) elif math.fabs(both_max - self.grid[i, -1]) > 1e-7: raise RuntimeError( 'Received data that was out of bounds for the specified grid. \ Grid bounds were ({}, {}), but min = {}, \ max = {}'.format( self.grid_bounds[i][0], self.grid_bounds[i][1], both_min, both_max)) Js1[i], Cs1[i] = Interpolation().interpolate( self.grid[i], x1.data[:, i]) Js2[i], Cs2[i] = Interpolation().interpolate( self.grid[i], x2.data[:, i]) return Js1, Cs1, Js2, Cs2 both_min = torch.min(x1.min(0)[0].data, x2.min(0)[0].data)[0] both_max = torch.max(x1.max(0)[0].data, x2.max(0)[0].data)[0] if both_min < self.grid_bounds[0][0] or both_max > self.grid_bounds[0][ 1]: # Out of bounds data is still ok if we are specifically computing kernel values for grid entries. if math.fabs(both_min - self.grid[0, 0]) > 1e-7: raise RuntimeError( 'Received data that was out of bounds for the specified grid. 
\ Grid bounds were ({}, {}), but min = {}, max = {}' .format(self.grid_bounds[0][0], self.grid_bounds[0][1], both_min, both_max)) elif math.fabs(both_max - self.grid[0, -1]) > 1e-7: raise RuntimeError( 'Received data that was out of bounds for the specified grid. \ Grid bounds were ({}, {}), but min = {}, max = {}' .format(self.grid_bounds[0][0], self.grid_bounds[0][1], both_min, both_max)) J1, C1 = Interpolation().interpolate(self.grid[0], x1.data.squeeze()) J2, C2 = Interpolation().interpolate(self.grid[0], x2.data.squeeze()) return J1, C1, J2, C2