def test_rq_decomposition():
    """Check that the R and Q factors multiply back to the input matrix.

    A random 10x10 float64 matrix is decomposed with split axis 1, and the
    matrix product of the two returned factors is compared against the
    original using the default tolerances of ``np.testing.assert_allclose``.
    """
    original = torch.rand([10, 10], dtype=torch.float64)
    r_factor, q_factor = decompositions.rq_decomposition(torch, original, 1)

    # Multiplying the factors in (r, q) order must reproduce the input.
    reconstructed = r_factor.mm(q_factor)
    np.testing.assert_allclose(reconstructed, original)
def rq_decomposition(
    self,
    tensor: Tensor,
    split_axis: int,
) -> Tuple[Tensor, Tensor]:
    """Compute the RQ decomposition of ``tensor``.

    This is a thin delegate: the work is done by
    ``decompositions.rq_decomposition``, which receives this object's
    ``torch`` attribute together with the unmodified arguments.

    Args:
      tensor: The tensor to decompose; forwarded unchanged.
      split_axis: Forwarded unchanged. Presumably the index at which the
        axes of ``tensor`` are divided between the two factors -- confirm
        against ``decompositions.rq_decomposition``.

    Returns:
      Whatever ``decompositions.rq_decomposition`` returns; annotated as a
      pair of tensors (presumably ``(r, q)``).
    """
    torch_module = self.torch
    factors = decompositions.rq_decomposition(torch_module, tensor, split_axis)
    return factors
def test_expected_shapes_rq():
    """Check the shapes of the factors returned for a rank-4 input.

    A zero tensor of shape (2, 3, 4, 5) is decomposed with split axis 2;
    ``r`` is expected to have shape (2, 3, 6) and ``q`` shape (6, 4, 5).
    """
    tensor = torch.zeros((2, 3, 4, 5))
    r_factor, q_factor = decompositions.rq_decomposition(torch, tensor, 2)

    # The first two input axes stay on r and the last two on q; the factors
    # are joined by a new axis of size 6 (presumably min(2 * 3, 4 * 5)).
    assert r_factor.shape == (2, 3, 6)
    assert q_factor.shape == (6, 4, 5)