コード例 #1
0
    def test_batch_left_t_interp_on_a_batch_matrix(self):
        """Check batched ``left_t_interp`` against a dense transposed matmul.

        A random ``(2, 9, 3)`` batch of right-hand sides is pushed through
        ``left_t_interp`` with the batched interpolation indices/values, and
        the result is compared with ``batch_interp_matrix^T @ rhs`` computed
        densely via ``torch.matmul``.
        """
        rhs = torch.randn(2, 9, 3)

        # The trailing 6 is forwarded as left_t_interp's last argument
        # (presumably the output row count -- confirm against its signature).
        result = left_t_interp(
            self.batch_interp_indices, self.batch_interp_values, rhs, 6
        )
        # Dense reference: swap the last two dims of the interpolation matrix.
        interp_t = self.batch_interp_matrix.transpose(-1, -2)
        expected = torch.matmul(interp_t, rhs)
        self.assertTrue(test._utils.approx_equal(result, expected))
コード例 #2
0
    def test_batch_left_t_interp_on_a_vector(self):
        """Check batched ``left_t_interp`` applied to a single 1-D vector.

        The dense reference reshapes the length-9 vector into a ``(1, 9, 1)``
        column with a leading singleton batch dim so ``torch.matmul`` can
        broadcast it against every batch entry of the transposed
        interpolation matrix; the trailing singleton column is then dropped.
        """
        vec = torch.randn(9)

        # (9,) -> (9, 1) -> (1, 9, 1)
        vec_as_batch_column = vec.unsqueeze(-1).unsqueeze(0)
        expected = torch.matmul(
            self.batch_interp_matrix.transpose(-1, -2), vec_as_batch_column
        ).squeeze(-1)
        result = left_t_interp(
            self.batch_interp_indices, self.batch_interp_values, vec, 6
        )
        self.assertTrue(test._utils.approx_equal(result, expected))
コード例 #3
0
    def test_left_t_interp_on_a_vector(self):
        """Check non-batched ``left_t_interp`` on a 1-D vector.

        Compares the sparse result with ``interp_matrix^T @ v`` computed
        densely via ``torch.matmul``.
        """
        v = torch.randn(9)

        computed = left_t_interp(self.interp_indices, self.interp_values, v, 6)
        interp_matrix_t = self.interp_matrix.transpose(-1, -2)
        reference = torch.matmul(interp_matrix_t, v)
        self.assertTrue(test._utils.approx_equal(computed, reference))
コード例 #4
0
    def test_left_t_interp_on_a_matrix(self):
        """Check non-batched ``left_t_interp`` on a ``(9, 3)`` matrix.

        Compares the sparse result with ``interp_matrix^T @ M`` computed
        densely via ``torch.matmul``.
        """
        rhs_matrix = torch.randn(9, 3)

        computed = left_t_interp(
            self.interp_indices, self.interp_values, rhs_matrix, 6
        )
        reference = torch.matmul(
            self.interp_matrix.transpose(-1, -2), rhs_matrix
        )
        # NOTE(review): the other examples call ``test._utils.approx_equal``;
        # this one relies on a bare ``approx_equal`` being in scope (e.g. via
        # ``from test._utils import approx_equal``) -- confirm the import.
        self.assertTrue(approx_equal(computed, reference))