def test_unitary_size(self, tol):
        """Check that matrices smaller than 3x3 are rejected.

        ``sun_compact`` is given a random 2x2 interferometer and must raise a
        ``ValueError`` instead of decomposing it.
        """
        expected_msg = "Input matrix for decomposition must be at least 3x3."
        two_mode_unitary = random_interferometer(2)

        with pytest.raises(ValueError, match=expected_msg):
            dec.sun_compact(two_mode_unitary, tol)
    def test_interferometer_reconstruction(self, n, tol):
        """Check that a U(n) matrix is recovered from its SU(n) factorization.

        The SU(n) part is rebuilt from the factorization parameters and then
        rescaled by an n-th root of ``exp(1j * phase)``, where ``phase`` is the
        global phase returned by ``sun_compact``.
        """
        original = random_interferometer(n)

        params, phase = dec.sun_compact(original, tol)

        # n-th root of the phase factor e^{i*phase}: the scalar that restores
        # the global phase removed when forming the SU(n) part.
        phase_root = np.exp(1j * phase)**(1 / n)
        rebuilt = phase_root * self._sun_reconstruction(n, params)

        assert np.allclose(rebuilt, original)
    def test_SU_reconstruction(self, n, tol):
        """Check that an SU(n) matrix is recovered (within ``tol``) from its factorization.

        Because the input has unit determinant, ``sun_compact`` must report no
        global phase (``None``).
        """
        # Divide a random unitary by an n-th root of its determinant so the
        # result has determinant 1, i.e. lies in SU(n).
        raw = random_interferometer(n)
        target = raw / np.linalg.det(raw)**(1 / n)

        params, returned_phase = dec.sun_compact(target, tol)
        rebuilt = self._sun_reconstruction(n, params)

        assert returned_phase is None
        assert np.allclose(target, rebuilt, atol=tol, rtol=0)
    def test_embeded_unitary(self, n, permutation, phase, tol):
        """Check factorization of a U(n-1) block embedded in a U(n) matrix.

        The top-left entry holds the phase ``exp(1j * phase)`` and the
        lower-right (n-1)x(n-1) block holds a random unitary; the rows are then
        permuted. The first entry of the factorization must carry exactly the
        parameters ``[0.0, 0.0, 0.0]``.
        """
        # Block-diagonal embedding: a 1x1 phase plus an (n-1)x(n-1) unitary.
        embedded = np.zeros((n, n), dtype=complex)
        embedded[0, 0] = np.exp(1j * phase)
        embedded[1:, 1:] = random_interferometer(n - 1)

        # Reorder the rows according to the parametrized permutation.
        embedded = embedded[permutation, :]

        params, _ = dec.sun_compact(embedded, tol)
        _, leading_params = params[0]

        assert leading_params == [0.0, 0.0, 0.0]
    def test_global_phase(self, SU_matrix, tol):
        """Test the global phase factored out of a random 3x3 unitary.

        When ``SU_matrix`` is true, the matrix is first divided by an n-th root
        of its determinant so that it has unit determinant. ``sun_compact``
        must then report no global phase (``None``). Otherwise the reported
        phase must equal the argument of the determinant, modulo 2*pi.
        """
        n = 3
        # Generate a random U(n) matrix and record its determinant.
        U = random_interferometer(n)
        det = np.linalg.det(U)
        if SU_matrix:
            # Divide out an n-th root of the determinant so that det(U) == 1.
            U /= det**(1 / n)

        # get result from factorization
        _, global_phase = dec.sun_compact(U, tol)

        if SU_matrix:
            assert global_phase is None
        else:
            assert global_phase is not None
            expected_phase = np.angle(det)
            # Phases are only defined modulo 2*pi, so compare the corresponding
            # points on the unit circle rather than the raw angles. This avoids
            # a spurious failure when the angle lies near the +/- pi branch cut.
            assert np.allclose(
                np.exp(1j * global_phase), np.exp(1j * expected_phase), atol=tol
            )
 def test_unitary_validation(self, tol):
     """Check that ``sun_compact`` raises ``ValueError`` for a non-unitary input."""
     # A 5x5 matrix with independent uniform random real and imaginary parts;
     # such a matrix is, with overwhelming probability, not unitary.
     real_part = np.random.random([5, 5])
     imag_part = np.random.random([5, 5])
     non_unitary = real_part + 1j * imag_part

     with pytest.raises(ValueError, match="The input matrix is not unitary."):
         dec.sun_compact(non_unitary, tol)