def test_compute_linear_truncated_kernel_no_batch(self):
    """Compare LinearTruncatedFidelityKernel to a hand-computed reference on
    unbatched 2x3 inputs, for every supported nu and both fidelity-dim
    configurations; also exercise diag mode and the unsupported
    ``last_dim_is_batch`` flag."""
    inputs_a = torch.tensor([[1, 0.1, 0.2], [2, 0.3, 0.4]])
    inputs_b = torch.tensor([[3, 0.5, 0.6], [4, 0.7, 0.8]])
    # Precomputed contribution of the first fidelity dimension.
    term_one = torch.tensor([[0.3584, 0.1856], [0.2976, 0.1584]])
    for nu in (0.5, 1.5, 2.5):
        for fidelity_dims in ([2], [1, 2]):
            kernel = LinearTruncatedFidelityKernel(
                fidelity_dims=fidelity_dims, dimension=3, nu=nu
            )
            kernel.power = 1
            if len(fidelity_dims) == 1:
                matern_dims = [0, 1]
                bias = 1 + term_one
            else:
                matern_dims = [0]
                # Precomputed extra terms from the second fidelity dimension.
                term_two = torch.tensor([[0.4725, 0.2889], [0.4025, 0.2541]])
                term_three = torch.tensor([[0.1685, 0.0531], [0.1168, 0.0386]])
                bias = 1 + term_one + term_two + term_three
            base_kernel = MaternKernel(nu=nu, active_dims=matern_dims)
            expected = bias * base_kernel(inputs_a, inputs_b).evaluate()
            result = kernel(inputs_a, inputs_b).evaluate()
            self.assertLess(torch.norm(result - expected), 1e-4)
            # Diagonal mode must agree with the diagonal of the full matrix.
            diag_result = kernel(inputs_a, inputs_b, diag=True)
            self.assertLess(torch.norm(diag_result - expected.diag()), 1e-4)
            # last_dim_is_batch is unsupported and must raise.
            with self.assertRaises(NotImplementedError):
                kernel(inputs_a, inputs_b, diag=True, last_dim_is_batch=True)
def test_compute_linear_truncated_kernel_no_batch(self):
    """Compare LinearTruncatedFidelityKernel to a hand-computed reference on
    unbatched 2x3 inputs, for every supported nu and both fidelity-dim
    configurations.

    Fix: also assert diagonal mode and the ``last_dim_is_batch``
    NotImplementedError, matching the coverage of the sibling variant of
    this test elsewhere in the file.
    """
    x1 = torch.tensor([1, 0.1, 0.2, 2, 0.3, 0.4], dtype=torch.float).view(2, 3)
    x2 = torch.tensor([3, 0.5, 0.6, 4, 0.7, 0.8], dtype=torch.float).view(2, 3)
    # Precomputed contribution of the first fidelity dimension.
    t_1 = torch.tensor(
        [0.3584, 0.1856, 0.2976, 0.1584], dtype=torch.float
    ).view(2, 2)
    for nu in {0.5, 1.5, 2.5}:
        for fidelity_dims in ([2], [1, 2]):
            kernel = LinearTruncatedFidelityKernel(
                fidelity_dims=fidelity_dims, dimension=3, nu=nu
            )
            kernel.power = 1
            if len(fidelity_dims) > 1:
                active_dimsM = [0]
                # Precomputed extra terms from the second fidelity dimension.
                t_2 = torch.tensor(
                    [0.4725, 0.2889, 0.4025, 0.2541], dtype=torch.float
                ).view(2, 2)
                t_3 = torch.tensor(
                    [0.1685, 0.0531, 0.1168, 0.0386], dtype=torch.float
                ).view(2, 2)
                t = 1 + t_1 + t_2 + t_3
            else:
                active_dimsM = [0, 1]
                t = 1 + t_1
            matern_ker = MaternKernel(nu=nu, active_dims=active_dimsM)
            matern_term = matern_ker(x1, x2).evaluate()
            actual = t * matern_term
            res = kernel(x1, x2).evaluate()
            self.assertLess(torch.norm(res - actual), 1e-4)
            # Diagonal mode must agree with the diagonal of the full matrix.
            res_diag = kernel(x1, x2, diag=True)
            self.assertLess(torch.norm(res_diag - actual.diag()), 1e-4)
            # last_dim_is_batch is unsupported and must raise.
            with self.assertRaises(NotImplementedError):
                kernel(x1, x2, diag=True, last_dim_is_batch=True)
def test_compute_linear_truncated_kernel_with_batch(self):
    """Compare LinearTruncatedFidelityKernel to a hand-computed reference on
    batched (2 x 2 x 3) inputs, for every supported nu and both fidelity-dim
    configurations; also exercise diag mode and the unsupported
    ``last_dim_is_batch`` flag."""
    batch_shape = torch.Size([2])
    inputs_a = torch.tensor(
        [[[1.0, 0.1, 0.2], [3.0, 0.3, 0.4]], [[5.0, 0.5, 0.6], [7.0, 0.7, 0.8]]]
    )
    inputs_b = torch.tensor(
        [[[2.0, 0.8, 0.7], [4.0, 0.6, 0.5]], [[6.0, 0.4, 0.3], [8.0, 0.2, 0.1]]]
    )
    # Precomputed contribution of the first fidelity dimension.
    term_one = torch.tensor(
        [[[0.2736, 0.44], [0.2304, 0.36]], [[0.3304, 0.3816], [0.1736, 0.1944]]]
    )
    for nu, fidelity_dims in itertools.product({0.5, 1.5, 2.5}, ([2], [1, 2])):
        kernel = LinearTruncatedFidelityKernel(
            fidelity_dims=fidelity_dims, dimension=3, nu=nu, batch_shape=batch_shape
        )
        kernel.power = 1
        if len(fidelity_dims) == 1:
            matern_dims = [0, 1]
            bias = 1 + term_one
        else:
            matern_dims = [0]
            # Precomputed extra terms from the second fidelity dimension.
            term_two = torch.tensor(
                [
                    [[0.0527, 0.167], [0.0383, 0.1159]],
                    [[0.1159, 0.167], [0.0383, 0.0527]],
                ]
            )
            term_three = torch.tensor(
                [
                    [[0.1944, 0.3816], [0.1736, 0.3304]],
                    [[0.36, 0.44], [0.2304, 0.2736]],
                ]
            )
            bias = 1 + term_one + term_two + term_three
        base_kernel = MaternKernel(
            nu=nu, active_dims=matern_dims, batch_shape=batch_shape
        )
        expected = bias * base_kernel(inputs_a, inputs_b).evaluate()
        result = kernel(inputs_a, inputs_b).evaluate()
        self.assertLess(torch.norm(result - expected), 1e-4)
        # Diagonal mode must agree with the batched diagonal of the full matrix.
        diag_result = kernel(inputs_a, inputs_b, diag=True)
        self.assertLess(
            torch.norm(diag_result - torch.diagonal(expected, dim1=-1, dim2=-2)),
            1e-4,
        )
        # last_dim_is_batch is unsupported and must raise.
        with self.assertRaises(NotImplementedError):
            kernel(inputs_a, inputs_b, diag=True, last_dim_is_batch=True)
def test_compute_linear_truncated_kernel_with_batch(self):
    """Compare LinearTruncatedFidelityKernel to a hand-computed reference on
    batched (2 x 2 x 3) inputs, for every supported nu and both fidelity-dim
    configurations; also exercise diag mode and the unsupported
    ``last_dim_is_batch`` flag."""
    x_first = torch.tensor(
        [1.0, 0.1, 0.2, 3.0, 0.3, 0.4, 5.0, 0.5, 0.6, 7.0, 0.7, 0.8]
    ).view(2, 2, 3)
    x_second = torch.tensor(
        [2.0, 0.8, 0.7, 4.0, 0.6, 0.5, 6.0, 0.4, 0.3, 8.0, 0.2, 0.1]
    ).view(2, 2, 3)
    # Precomputed contribution of the first fidelity dimension.
    term_a = torch.tensor(
        [0.2736, 0.4400, 0.2304, 0.3600, 0.3304, 0.3816, 0.1736, 0.1944]
    ).view(2, 2, 2)
    batch_shape = torch.Size([2])
    for nu in (0.5, 1.5, 2.5):
        for fidelity_dims in ([2], [1, 2]):
            kernel = LinearTruncatedFidelityKernel(
                fidelity_dims=fidelity_dims,
                dimension=3,
                nu=nu,
                batch_shape=batch_shape,
            )
            kernel.power = 1
            if len(fidelity_dims) == 1:
                matern_dims = [0, 1]
                scale = 1 + term_a
            else:
                matern_dims = [0]
                # Precomputed extra terms from the second fidelity dimension.
                term_b = torch.tensor(
                    [0.0527, 0.1670, 0.0383, 0.1159, 0.1159, 0.1670, 0.0383, 0.0527]
                ).view(2, 2, 2)
                term_c = torch.tensor(
                    [0.1944, 0.3816, 0.1736, 0.3304, 0.3600, 0.4400, 0.2304, 0.2736]
                ).view(2, 2, 2)
                scale = 1 + term_a + term_b + term_c
            base = MaternKernel(
                nu=nu, active_dims=matern_dims, batch_shape=batch_shape
            )
            expected = scale * base(x_first, x_second).evaluate()
            result = kernel(x_first, x_second).evaluate()
            self.assertLess(torch.norm(result - expected), 1e-4)
            # Diagonal mode must agree with the batched diagonal.
            diag_result = kernel(x_first, x_second, diag=True)
            diag_expected = torch.diagonal(expected, dim1=-1, dim2=-2)
            self.assertLess(torch.norm(diag_result - diag_expected), 1e-4)
            # last_dim_is_batch is unsupported and must raise.
            with self.assertRaises(NotImplementedError):
                kernel(x_first, x_second, diag=True, last_dim_is_batch=True)
def test_compute_linear_truncated_kernel_with_batch(self):
    """Compare LinearTruncatedFidelityKernel to a hand-computed reference on
    batched (2 x 2 x 3) inputs, for every supported nu and both fidelity-dim
    configurations.

    Fix: also assert diagonal mode and the ``last_dim_is_batch``
    NotImplementedError, matching the coverage of the sibling variants of
    this test elsewhere in the file.
    """
    x1 = torch.tensor(
        [1, 0.1, 0.2, 3, 0.3, 0.4, 5, 0.5, 0.6, 7, 0.7, 0.8],
        dtype=torch.float,
    ).view(2, 2, 3)
    x2 = torch.tensor(
        [2, 0.8, 0.7, 4, 0.6, 0.5, 6, 0.4, 0.3, 8, 0.2, 0.1],
        dtype=torch.float,
    ).view(2, 2, 3)
    # Precomputed contribution of the first fidelity dimension.
    t_1 = torch.tensor(
        [0.2736, 0.44, 0.2304, 0.36, 0.3304, 0.3816, 0.1736, 0.1944],
        dtype=torch.float,
    ).view(2, 2, 2)
    batch_shape = torch.Size([2])
    for nu in {0.5, 1.5, 2.5}:
        for fidelity_dims in ([2], [1, 2]):
            kernel = LinearTruncatedFidelityKernel(
                fidelity_dims=fidelity_dims,
                dimension=3,
                nu=nu,
                batch_shape=batch_shape,
            )
            kernel.power = 1
            if len(fidelity_dims) > 1:
                active_dimsM = [0]
                # Precomputed extra terms from the second fidelity dimension.
                t_2 = torch.tensor(
                    [0.0527, 0.167, 0.0383, 0.1159, 0.1159, 0.167, 0.0383, 0.0527],
                    dtype=torch.float,
                ).view(2, 2, 2)
                t_3 = torch.tensor(
                    [0.1944, 0.3816, 0.1736, 0.3304, 0.36, 0.44, 0.2304, 0.2736],
                    dtype=torch.float,
                ).view(2, 2, 2)
                t = 1 + t_1 + t_2 + t_3
            else:
                active_dimsM = [0, 1]
                t = 1 + t_1
            matern_ker = MaternKernel(
                nu=nu, active_dims=active_dimsM, batch_shape=batch_shape
            )
            matern_term = matern_ker(x1, x2).evaluate()
            actual = t * matern_term
            res = kernel(x1, x2).evaluate()
            self.assertLess(torch.norm(res - actual), 1e-4)
            # Diagonal mode must agree with the batched diagonal.
            res_diag = kernel(x1, x2, diag=True)
            self.assertLess(
                torch.norm(res_diag - torch.diagonal(actual, dim1=-1, dim2=-2)),
                1e-4,
            )
            # last_dim_is_batch is unsupported and must raise.
            with self.assertRaises(NotImplementedError):
                kernel(x1, x2, diag=True, last_dim_is_batch=True)