コード例 #1
0
ファイル: interp_test.py プロジェクト: soobleck/gpytorch
def test_left_t_interp_on_a_vector():
    vector = torch.randn(9)

    res = left_t_interp(interp_indices, interp_values, Variable(vector),
                        6).data
    actual = torch.matmul(interp_matrix.transpose(-1, -2), vector)
    assert approx_equal(res, actual)
コード例 #2
0
ファイル: interp_test.py プロジェクト: soobleck/gpytorch
def test_batch_left_t_interp_on_a_batch_matrix():
    batch_matrix = torch.randn(2, 9, 3)

    res = left_t_interp(batch_interp_indices, batch_interp_values,
                        Variable(batch_matrix), 6).data
    actual = torch.matmul(batch_interp_matrix.transpose(-1, -2), batch_matrix)
    assert approx_equal(res, actual)
コード例 #3
0
ファイル: interp_test.py プロジェクト: soobleck/gpytorch
def test_left_t_interp_on_a_matrix():
    matrix = torch.randn(9, 3)

    res = left_t_interp(interp_indices, interp_values, Variable(matrix),
                        6).data
    actual = torch.matmul(interp_matrix.transpose(-1, -2), matrix)
    assert approx_equal(res, actual)
コード例 #4
0
ファイル: interp_test.py プロジェクト: soobleck/gpytorch
def test_batch_left_t_interp_on_a_vector():
    vector = torch.randn(9)

    actual = torch.matmul(batch_interp_matrix.transpose(-1, -2),
                          vector.unsqueeze(-1).unsqueeze(0)).squeeze(0)
    res = left_t_interp(batch_interp_indices, batch_interp_values,
                        Variable(vector), 6).data
    assert approx_equal(res, actual)
コード例 #5
0
    def test_left_t_interp_on_a_matrix(self):
        matrix = torch.randn(9, 3)

        res = left_t_interp(
            self.interp_indices, self.interp_values, Variable(matrix), 6
        ).data
        actual = torch.matmul(self.interp_matrix.transpose(-1, -2), matrix)
        self.assertTrue(approx_equal(res, actual))
コード例 #6
0
ファイル: test_interp.py プロジェクト: colesbury/gpytorch
    def test_left_t_interp_on_a_vector(self):
        vector = torch.randn(9)

        res = left_t_interp(
            self.interp_indices,
            self.interp_values,
            Variable(vector),
            6,
        ).data
        actual = torch.matmul(self.interp_matrix.transpose(-1, -2), vector)
        self.assertTrue(approx_equal(res, actual))
コード例 #7
0
    def test_batch_left_t_interp_on_a_batch_matrix(self):
        batch_matrix = torch.randn(2, 9, 3)

        res = left_t_interp(
            self.batch_interp_indices,
            self.batch_interp_values,
            Variable(batch_matrix),
            6,
        ).data
        actual = torch.matmul(self.batch_interp_matrix.transpose(-1, -2), batch_matrix)
        self.assertTrue(approx_equal(res, actual))
コード例 #8
0
    def test_batch_left_t_interp_on_a_vector(self):
        vector = torch.randn(9)

        actual = torch.matmul(self.batch_interp_matrix.transpose(-1, -2),
                              vector.unsqueeze(-1).unsqueeze(0)).squeeze(-1)
        res = left_t_interp(
            self.batch_interp_indices,
            self.batch_interp_values,
            Variable(vector),
            6,
        ).data
        self.assertTrue(approx_equal(res, actual))