def test_batch_left_t_interp_on_a_batch_matrix(self):
    """Check left_t_interp on a batch of matrices against the dense product.

    The reference is the batched dense interpolation matrix, transposed on its
    last two dims, matrix-multiplied with the same random (2, 9, 3) input.
    The trailing ``6`` is forwarded unchanged to ``left_t_interp`` (presumably
    the output size along the interpolated dimension — confirm in its docs).
    """
    dense_rhs = torch.randn(2, 9, 3)
    result = left_t_interp(
        self.batch_interp_indices,
        self.batch_interp_values,
        dense_rhs,
        6,
    )
    expected = self.batch_interp_matrix.transpose(-2, -1) @ dense_rhs
    self.assertTrue(test._utils.approx_equal(result, expected))
def test_batch_left_t_interp_on_a_vector(self):
    """Check batched left_t_interp applied to a single (unbatched) vector.

    The reference lifts the length-9 vector to shape (1, 9, 1) so that
    matmul broadcasts it across the batch of transposed dense interpolation
    matrices, then drops the trailing singleton column.
    """
    rhs_vec = torch.randn(9)
    # (9,) -> (1, 9, 1): leading dim broadcasts over the batch, trailing dim
    # makes it a column for the matrix product.
    lifted = rhs_vec[None, :, None]
    expected = (self.batch_interp_matrix.transpose(-2, -1) @ lifted).squeeze(-1)
    result = left_t_interp(
        self.batch_interp_indices,
        self.batch_interp_values,
        rhs_vec,
        6,
    )
    self.assertTrue(test._utils.approx_equal(result, expected))
def test_left_t_interp_on_a_vector(self):
    """Check non-batched left_t_interp on a vector against W^T @ v.

    ``W`` is the dense ``self.interp_matrix`` transposed on its last two dims;
    the random input is a length-9 vector.
    """
    rhs_vec = torch.randn(9)
    result = left_t_interp(self.interp_indices, self.interp_values, rhs_vec, 6)
    expected = self.interp_matrix.transpose(-2, -1) @ rhs_vec
    self.assertTrue(test._utils.approx_equal(result, expected))
def test_left_t_interp_on_a_matrix(self):
    """Check non-batched left_t_interp on a (9, 3) matrix against W^T @ M.

    ``W`` is the dense ``self.interp_matrix`` transposed on its last two dims.
    The trailing ``6`` is forwarded unchanged to ``left_t_interp`` (presumably
    the output size along the interpolated dimension — confirm in its docs).
    """
    matrix = torch.randn(9, 3)
    res = left_t_interp(self.interp_indices, self.interp_values, matrix, 6)
    actual = torch.matmul(self.interp_matrix.transpose(-1, -2), matrix)
    # Qualified the same way as the other tests in this module: the bare name
    # `approx_equal` is not guaranteed to be imported here and would raise
    # NameError, whereas `test._utils` is already relied upon elsewhere.
    self.assertTrue(test._utils.approx_equal(res, actual))