def test_init_system_returns_expected_wavefunction_size(self, n_qubits):
        """init_system(n_qubits) yields a 2**n_qubits-long state with all weight on index 0."""
        wf = Wavefunction.init_system(n_qubits=n_qubits)
        expected_dim = 2**n_qubits

        # The container length and the recorded qubit count must both agree
        # with the requested system size.
        assert len(wf) == expected_dim
        assert wf.n_qubits == n_qubits

        # The first amplitude is exactly 1.0 ...
        assert wf[0] == 1.0

        # ... and every remaining amplitude is zero.
        assert not np.any(wf[1:])
class TestRepresentations:
    """Tests for Wavefunction string rendering, output dtypes and outcome probabilities."""

    def test_string_output_of_symbolic_wavefunction(self):
        """str() of a symbolic wavefunction mentions the symbol, in Wavefunction([...]) form."""
        wf = Wavefunction([Symbol("alpha"), 0])

        wf_str = str(wf)

        assert "alpha" in wf_str
        assert wf_str.endswith("])")
        assert wf_str.startswith("Wavefunction([")

    def test_string_output_of_numeric_wavefunction(self):
        """str() of a complex-valued wavefunction shows the imaginary unit ``j``."""
        wf = Wavefunction([1j, 0])

        wf_str = str(wf)

        assert "j" in wf_str
        assert wf_str.endswith("])")
        assert wf_str.startswith("Wavefunction([")

    @pytest.mark.parametrize(
        "wf",
        [
            Wavefunction.init_system(2),
            Wavefunction([Symbol("alpha"), 0.0]),
        ],
    )
    def test_amplitudes_and_probs_output_type(self, wf: Wavefunction):
        """Symbolic wavefunctions expose object dtype; numeric ones complex128/float64."""
        if len(wf.free_symbols) > 0:
            # Symbolic entries are Python objects, hence the object dtype.
            assert wf.amplitudes.dtype == object
            assert wf.probabilities().dtype == object
        else:
            assert wf.amplitudes.dtype == np.complex128
            assert wf.probabilities().dtype == np.float64

    @pytest.mark.parametrize(
        "wf_vec",
        [
            [1.0, 0.0],
            [0.5, 0.5, 0.5, 0.5],
            [1 / sqrt(2), 0, 0, 0, 0, 0, 0, 1 / sqrt(2)],
        ],
    )
    def test_get_outcome_probs(self, wf_vec):
        """get_outcome_probs() agrees with probabilities() for every reported outcome.

        Each key must be a bitstring with one character per qubit; parsed as a
        base-2 integer it indexes the matching entry of ``probabilities()``.
        """
        wf = Wavefunction(wf_vec)
        probs_dict = wf.get_outcome_probs()
        # Computed once; the loop below only indexes into it.
        probabilities = wf.probabilities()

        # Guard against a vacuous pass: every test vector is normalized, so at
        # least one outcome must be reported.
        assert probs_dict

        for key, prob in probs_dict.items():
            assert len(key) == wf.n_qubits
            # Tolerant comparison: the two values may be derived through
            # different floating-point paths.
            assert prob == pytest.approx(probabilities[int(key, 2)])
 def test_init_system_fails_on_invalid_params(self, n_qubits):
     """init_system must raise ValueError for the supplied n_qubits.

     NOTE(review): n_qubits is expected to come from a fixture/parametrization
     of invalid values defined outside this view — confirm.
     """
     make_system = Wavefunction.init_system
     with pytest.raises(ValueError):
         make_system(n_qubits=n_qubits)
 def test_init_system_raises_warning_for_non_ints(self):
     """Passing a non-integer qubit count to init_system emits a UserWarning."""
     non_integer_qubits = 1.234
     with pytest.warns(UserWarning):
         Wavefunction.init_system(non_integer_qubits)
 def test_init_system_returns_numpy_array(self):
     """init_system stores its internal ``_amplitude_vector`` as a numpy ndarray."""
     state = Wavefunction.init_system(2)
     internal_vector = state._amplitude_vector
     assert isinstance(internal_vector, np.ndarray)