def test_qr_decomposition():
    """Check the QR factors of a random 10x10 float64 matrix.

    Asserts that q @ r reconstructs the input, that q has orthonormal
    columns, and that r is upper triangular.
    """
    # A fixed-seed local generator keeps the test reproducible without
    # touching torch's global RNG state.
    generator = torch.Generator().manual_seed(0)
    random_matrix = torch.rand([10, 10], dtype=torch.float64, generator=generator)
    q, r = decompositions.qr_decomposition(torch, random_matrix, 1)
    # Reconstruction alone would also accept a trivial split such as
    # (identity, matrix). The orthonormality and triangularity checks below
    # rule that out.
    # atol guards entries of random_matrix that happen to be very close to 0,
    # where a purely relative tolerance is too strict.
    np.testing.assert_allclose(q.mm(r), random_matrix, atol=1e-12)
    np.testing.assert_allclose(
        q.t().mm(q), torch.eye(10, dtype=torch.float64), atol=1e-12
    )
    np.testing.assert_allclose(torch.triu(r), r, atol=1e-12)
def qr_decomposition(
    self,
    tensor: Tensor,
    split_axis: int,
) -> Tuple[Tensor, Tensor]:
    """Return the QR factors of `tensor`, split at `split_axis`.

    This is a thin wrapper. It forwards to
    `decompositions.qr_decomposition` and passes this backend's torch module
    (`self.torch`) as the first argument.

    Args:
      tensor: The tensor to factorize.
      split_axis: Axis index at which the tensor's axes are divided into two
        groups for the factorization. NOTE(review): presumably axes
        [:split_axis] go to q and [split_axis:] to r. Confirm against
        decompositions.qr_decomposition.

    Returns:
      The `(q, r)` pair produced by `decompositions.qr_decomposition`.
    """
    # Look up the helper before reading self.torch, preserving the original
    # evaluation order.
    qr_helper = decompositions.qr_decomposition
    return qr_helper(self.torch, tensor, split_axis)
def test_expected_shapes_qr():
    """Check the QR factor shapes for a rank-4 tensor split at axis 2."""
    zeros_tensor = torch.zeros((2, 3, 4, 5))
    q_factor, r_factor = decompositions.qr_decomposition(torch, zeros_tensor, 2)
    # Axes (2, 3) flatten to 6 and axes (4, 5) flatten to 20. The expected
    # bond dimension of 6 matches min(6, 20), presumably because the QR
    # is the reduced form.
    expected_q_shape, expected_r_shape = (2, 3, 6), (6, 4, 5)
    assert q_factor.shape == expected_q_shape
    assert r_factor.shape == expected_r_shape