def test_coherent(self, hbar, tol):
        """Coherent state: photon-number mean and variance must both equal |a|^2.

        ``hbar`` and ``tol`` are supplied by the test harness (presumably
        pytest fixtures -- they are not defined in this method).
        """
        alpha = 0.23 + 0.12j
        expected = np.abs(alpha)**2

        # Vector of means: (Re a, Im a) scaled by sqrt(2 * hbar).
        displacement = np.sqrt(2 * hbar) * np.array([alpha.real, alpha.imag])
        # Covariance used for the coherent state: (hbar / 2) * identity.
        covariance = np.identity(2) * hbar / 2

        n_mean, n_var = symplectic.mean_photon_number(
            displacement, covariance, hbar=hbar
        )

        # Purely absolute comparison: rtol is switched off.
        closeness = dict(atol=tol, rtol=0)
        assert np.allclose(n_mean, expected, **closeness)
        assert np.allclose(n_var, expected, **closeness)
    def test_displaced_thermal(self, hbar, tol):
        """Displaced thermal state: check photon-number mean and variance.

        Expected values, with ``a`` the displacement and ``nbar`` the thermal
        occupation used below:

        * E(n)   = |a|^2 + nbar
        * var(n) = nbar^2 + nbar + |a|^2 * (1 + 2 * nbar)
        """
        nbar = 0.123
        alpha = 0.12 - 0.05j
        abs_sq = np.abs(alpha)**2

        # Vector of means: (Re a, Im a) scaled by sqrt(2 * hbar).
        displacement = np.sqrt(2 * hbar) * np.array([alpha.real, alpha.imag])
        # Diagonal covariance with entries (2 * nbar + 1) * hbar / 2.
        covariance = (hbar / 2) * np.diag(2 * np.tile(nbar, 2) + 1)

        n_mean, n_var = symplectic.mean_photon_number(
            displacement, covariance, hbar=hbar
        )

        expected_mean = abs_sq + nbar
        expected_var = nbar**2 + nbar + abs_sq * (1 + 2 * nbar)

        # Purely absolute comparison: rtol is switched off.
        closeness = dict(atol=tol, rtol=0)
        assert np.allclose(n_mean, expected_mean, **closeness)
        assert np.allclose(n_var, expected_var, **closeness)
    def test_displaced_squeezed(self, hbar, tol):
        """Displaced squeezed state: mean photon number is sinh^2(r) + |a|^2.

        Only the mean is asserted; the variance returned by
        ``mean_photon_number`` is discarded in this test.
        """
        alpha = 0.12 - 0.05j
        r = 0.1
        phi = 0.423

        # Hyperbolic / trigonometric building blocks of the 2x2 matrix S.
        ch, sh = np.cosh(r), np.sinh(r)
        cos_phi, sin_phi = np.cos(phi), np.sin(phi)
        off_diag = -sin_phi * sh

        squeeze = np.array([
            [ch - cos_phi * sh, off_diag],
            [off_diag, ch + cos_phi * sh],
        ])

        # Vector of means: (Re a, Im a) scaled by sqrt(2 * hbar).
        displacement = np.sqrt(2 * hbar) * np.array([alpha.real, alpha.imag])
        # Covariance: S S^T scaled by hbar / 2.
        gram = squeeze @ squeeze.T
        covariance = gram * hbar / 2

        n_mean, _ = symplectic.mean_photon_number(
            displacement, covariance, hbar=hbar
        )

        expected_mean = np.abs(alpha)**2 + sh**2
        assert np.allclose(n_mean, expected_mean, atol=tol, rtol=0)
    def test_squeezed(self, hbar, tol):
        """Undisplaced squeezed state: check photon-number mean and variance.

        Expected values for squeezing magnitude ``r``:

        * E(n)   = sinh^2(r)
        * var(n) = 2 * (sinh^2(r) + sinh^4(r))
        """
        r = 0.1
        phi = 0.423

        # Hyperbolic / trigonometric building blocks of the 2x2 matrix S.
        ch, sh = np.cosh(r), np.sinh(r)
        cos_phi, sin_phi = np.cos(phi), np.sin(phi)
        off_diag = -sin_phi * sh

        squeeze = np.array([
            [ch - cos_phi * sh, off_diag],
            [off_diag, ch + cos_phi * sh],
        ])

        # No displacement: the vector of means is zero.
        displacement = np.zeros([2])
        # Covariance: S S^T scaled by hbar / 2.
        gram = squeeze @ squeeze.T
        covariance = gram * hbar / 2

        n_mean, n_var = symplectic.mean_photon_number(
            displacement, covariance, hbar=hbar
        )

        expected_mean = sh**2
        expected_var = 2 * (sh**2 + sh**4)

        # Purely absolute comparison: rtol is switched off.
        closeness = dict(atol=tol, rtol=0)
        assert np.allclose(n_mean, expected_mean, **closeness)
        assert np.allclose(n_var, expected_var, **closeness)