Esempio n. 1
0
def test_default_tolerance():
    """sub_state_vector should succeed on a simulated state using its default tolerance.

    Measuring qubit ``a`` at the end of the circuit leaves the two-qubit state as a
    product state, so factoring out qubit 0 without an explicit ``atol`` must not raise.
    """
    q0, q1 = cirq.LineQubit.range(2)
    circuit = cirq.Circuit(cirq.H(q0), cirq.H(q1), cirq.CZ(q0, q1), cirq.measure(q0))
    result = cirq.Simulator().simulate(circuit)
    state = result.final_state_vector.reshape((2, 2))
    # `atol` is intentionally omitted: this only checks that the default is reasonable.
    cirq.sub_state_vector(state, [0])
Esempio n. 2
0
def test_sub_state_vector_non_kron():
    """Qubits of a GHZ factor cannot be split off, but an independent |+> qubit can."""
    ghz = np.array([1, 0, 0, 0, 0, 0, 0, 1]) / np.sqrt(2)
    plus = np.array([1, 1]) / np.sqrt(2)
    state = np.kron(ghz, plus).reshape((2, 2, 2, 2))

    # Any subset containing only part of the GHZ qubits (0, 1, 2) is not separable,
    # whether or not the independent qubit 3 is included.
    entangled_subsets = [[q] for q in range(3)] + [[q, 3] for q in range(3)]
    for subset in entangled_subsets:
        assert cirq.sub_state_vector(state, subset, default=None, atol=1e-8) is None

    # Without a `default`, an impossible factorization raises instead of returning.
    with pytest.raises(ValueError, match='factored'):
        _ = cirq.sub_state_vector(ghz, [0], atol=1e-8)

    # The trailing qubit is a genuine Kronecker factor and is recovered.
    recovered = cirq.sub_state_vector(state, [3])
    assert cirq.equal_up_to_global_phase(recovered, plus, atol=1e-8)
Esempio n. 3
0
def test_sub_state_vector_bad_subset():
    """Subsets that do not coincide with a whole factor of the state are rejected.

    The state is the Kronecker product of a random 2-qubit superposition (qubits 0-1)
    and a random 3-qubit superposition (qubits 2-4). Random superpositions are
    generically entangled, so only the complete factors are separable.
    """
    left = cirq.testing.random_superposition(4)
    right = cirq.testing.random_superposition(8)
    state = np.kron(left, right).reshape((2, 2, 2, 2, 2))

    bad_subsets = [[q] for q in range(5)]
    bad_subsets += [[q_left, q_right] for q_left in range(2) for q_right in range(2, 5)]
    bad_subsets += [[0, 1, q] for q in range(2, 5)]
    bad_subsets += [[2, 3, 4, q] for q in range(2)]

    for subset in bad_subsets:
        assert cirq.sub_state_vector(state, subset, default=None, atol=1e-8) is None
Esempio n. 4
0
def test_deprecated():
    """Deprecated `wavefunction` names and keywords still run and log a deprecation."""
    state = np.arange(4) / np.linalg.norm(np.arange(4))

    # Old function name for sub_state_vector.
    with cirq.testing.assert_logs('subwavefunction', 'sub_state_vector', 'deprecated'):
        _ = cirq.subwavefunction(state, [0, 1], atol=1e-8)

    # Old `wavefunction=` keyword on sub_state_vector.
    with cirq.testing.assert_logs('wavefunction', 'state_vector', 'deprecated'):
        # pylint: disable=unexpected-keyword-arg,no-value-for-parameter
        _ = cirq.sub_state_vector(wavefunction=state, keep_indices=[0, 1], atol=1e-8)

    # Old function name for partial_trace_of_state_vector_as_mixture.
    with cirq.testing.assert_logs(
        'wavefunction_partial_trace_as_mixture',
        'partial_trace_of_state_vector_as_mixture',
        'deprecated',
    ):
        _ = cirq.wavefunction_partial_trace_as_mixture(state, [0])

    # Old `wavefunction=` keyword on partial_trace_of_state_vector_as_mixture.
    with cirq.testing.assert_logs('wavefunction', 'state_vector', 'deprecated'):
        # pylint: disable=unexpected-keyword-arg,no-value-for-parameter
        _ = cirq.partial_trace_of_state_vector_as_mixture(wavefunction=state, keep_indices=[0])
Esempio n. 5
0
def test_sub_state_vector_invalid_inputs():
    """sub_state_vector raises ValueError for malformed states and invalid keep_indices."""

    # A size-7 array cannot describe a register of qubits (7 is not a power of two);
    # the error is expected to mention the offending size.
    with pytest.raises(ValueError, match='7'):
        cirq.sub_state_vector(np.arange(7), [1, 2], atol=1e-8)

    # State shape does not conform to input requirements: a tensor whose axes are
    # not all of size 2, and a column vector with a trailing singleton axis.
    with pytest.raises(ValueError, match='shaped'):
        cirq.sub_state_vector(np.arange(16).reshape((2, 4, 2)), [1, 2], atol=1e-8)
    with pytest.raises(ValueError, match='shaped'):
        cirq.sub_state_vector(np.arange(16).reshape((16, 1)), [1, 2], atol=1e-8)

    # Correct size and shape, but the vector is not normalized.
    with pytest.raises(ValueError, match='normalized'):
        cirq.sub_state_vector(np.arange(16), [1, 2], atol=1e-8)

    # Bad choice of input indices: a repeated index.
    state = np.arange(8) / np.linalg.norm(np.arange(8))
    with pytest.raises(ValueError, match='2, 2'):
        cirq.sub_state_vector(state, [1, 2, 2], atol=1e-8)

    # Bad choice of input indices: indices outside the 2-qubit register.
    state = np.array([1, 0, 0, 0]).reshape((2, 2))
    with pytest.raises(ValueError, match='invalid'):
        cirq.sub_state_vector(state, [5], atol=1e-8)
    with pytest.raises(ValueError, match='invalid'):
        cirq.sub_state_vector(state, [0, 1, 2], atol=1e-8)
Esempio n. 6
0
def test_sub_state_vector():
    """Each factor of a three-way Kronecker product state can be extracted.

    The 9-qubit state is kron(a, b, c) with a on qubits 0-1, b on qubits 2-4 and
    c on qubits 5-8. The checks cover the tensor-shaped and the flat input forms,
    plus the effect of very tight and very loose tolerances.
    """
    factor_a = np.arange(4) / np.linalg.norm(np.arange(4))
    factor_b = (np.arange(8) + 3) / np.linalg.norm(np.arange(8) + 3)
    factor_c = (np.arange(16) + 1) / np.linalg.norm(np.arange(16) + 1)
    state = np.kron(np.kron(factor_a, factor_b), factor_c).reshape((2,) * 9)

    # Each factor paired with the qubit indices it occupies inside `state`.
    layout = [
        (factor_a, (0, 1)),
        (factor_b, (2, 3, 4)),
        (factor_c, (5, 6, 7, 8)),
    ]

    # Keeping every qubit of a lone factor returns that factor.
    for vec, qubits in layout:
        every_qubit = list(range(len(qubits)))
        extracted = cirq.sub_state_vector(vec, every_qubit, atol=1e-8)
        assert cirq.equal_up_to_global_phase(extracted, vec)

    # Tensor-shaped input yields tensor-shaped output.
    for vec, qubits in layout:
        extracted = cirq.sub_state_vector(state, list(qubits), atol=1e-15)
        assert cirq.equal_up_to_global_phase(extracted, vec.reshape((2,) * len(qubits)))

    # Output state vector conforms to the shape of the input state vector.
    flat_state = state.reshape(-1)
    for vec, qubits in layout:
        extracted = cirq.sub_state_vector(flat_state, list(qubits), atol=1e-15)
        assert cirq.equal_up_to_global_phase(extracted, vec)

    # Reject factoring for very tight tolerance.
    for _, qubits in layout:
        assert cirq.sub_state_vector(state, list(qubits), default=None, atol=1e-16) is None

    # Permit invalid factoring for loose tolerance.
    for qubit in range(9):
        assert cirq.sub_state_vector(state, [qubit], default=None, atol=1) is not None