def test_deprecated():
    """Legacy ``wavefunction``-flavoured names and keywords must log deprecations.

    Each call below runs inside ``cirq.testing.assert_logs`` and must emit a
    log containing every listed fragment (old name, new name, 'deprecated').
    """
    state = np.arange(4) / np.linalg.norm(np.arange(4))

    # Old function name should point users at sub_state_vector.
    with cirq.testing.assert_logs('subwavefunction', 'sub_state_vector', 'deprecated'):
        cirq.subwavefunction(state, [0, 1], atol=1e-8)

    # Old keyword name should point users at state_vector.
    with cirq.testing.assert_logs('wavefunction', 'state_vector', 'deprecated'):
        # pylint: disable=unexpected-keyword-arg,no-value-for-parameter
        cirq.sub_state_vector(wavefunction=state, keep_indices=[0, 1], atol=1e-8)

    # Old partial-trace function name.
    with cirq.testing.assert_logs(
        'wavefunction_partial_trace_as_mixture',
        'partial_trace_of_state_vector_as_mixture',
        'deprecated',
    ):
        cirq.wavefunction_partial_trace_as_mixture(state, [0])

    # Old keyword name on the renamed partial-trace function.
    with cirq.testing.assert_logs('wavefunction', 'state_vector', 'deprecated'):
        # pylint: disable=unexpected-keyword-arg,no-value-for-parameter
        cirq.partial_trace_of_state_vector_as_mixture(wavefunction=state, keep_indices=[0])
def _substate(one_state: tf.Tensor) -> tf.Tensor:
    """Factor the sub-state on ``keep_indices`` out of a single state tensor.

    The tensor is converted to a NumPy array, handed to
    ``cirq.subwavefunction`` and the result is wrapped back into a tensor.

    NOTE(review): ``keep_indices`` and ``atol`` are not parameters here; they
    are read from an enclosing scope that is not visible in this chunk —
    confirm their definitions at the call site.
    NOTE(review): ``cirq.subwavefunction`` is deprecated in favour of
    ``cirq.sub_state_vector`` in cirq versions that provide the latter;
    consider migrating once the pinned cirq version is confirmed.

    Args:
        one_state: A single state as a tensor that supports ``.numpy()``.

    Returns:
        The factored sub-state, converted with ``tf.convert_to_tensor``.
    """
    # Imported locally (presumably to defer the cirq dependency) — confirm.
    import cirq

    state_array = one_state.numpy()
    factored = cirq.subwavefunction(state_array, keep_indices=keep_indices, atol=atol)
    return tf.convert_to_tensor(factored)
def test_subwavefunction_non_kron():
    """An entangled block cannot be split, but the product qubit can.

    The state is (|000> + |111>)/sqrt(2) on qubits 0-2 tensored with
    (|0> + |1>)/sqrt(2) on qubit 3. Any subset touching the entangled block
    without covering it fails to factor. Qubit 3 alone factors cleanly.
    """
    ghz = np.array([1, 0, 0, 0, 0, 0, 0, 1]) / np.sqrt(2)
    plus = np.array([1, 1]) / np.sqrt(2)
    state = np.kron(ghz, plus).reshape((2, 2, 2, 2))

    # A single qubit of the entangled block is not separable.
    for qubit in range(3):
        assert cirq.subwavefunction(state, [qubit], default=None, atol=1e-8) is None
    # Adding the product qubit does not make it separable either.
    for qubit in range(3):
        assert cirq.subwavefunction(state, [qubit, 3], default=None, atol=1e-8) is None

    # With no `default`, a failed factorization raises instead of returning.
    with pytest.raises(ValueError, match='factored'):
        cirq.subwavefunction(ghz, [0], atol=1e-8)

    recovered_plus = cirq.subwavefunction(state, [3])
    assert cirq.equal_up_to_global_phase(recovered_plus, plus, atol=1e-8)
def test_subwavefunction_bad_subset():
    """Subsets that cut across the two tensor factors cannot be factored out.

    The state is a random 2-qubit superposition (qubits 0-1) tensored with a
    random 3-qubit superposition (qubits 2-4). Every subset checked below
    only partially covers at least one of those factors.

    NOTE(review): this relies on each random factor being internally
    entangled, which holds almost surely. No random seed is fixed here.
    """
    first = cirq.testing.random_superposition(4)
    second = cirq.testing.random_superposition(8)
    state = np.kron(first, second).reshape((2,) * 5)

    def factor(indices):
        return cirq.subwavefunction(state, indices, default=None, atol=1e-8)

    for qubit in range(5):
        assert factor([qubit]) is None

    for left in range(2):
        for right in range(2, 5):
            assert factor([left, right]) is None

    for extra in range(2, 5):
        assert factor([0, 1, extra]) is None

    for extra in range(2):
        assert factor([2, 3, 4, extra]) is None
def test_subwavefunction_invalid_inputs():
    """Malformed states and bad index lists raise ValueError with a useful message."""

    def expect_value_error(fragment, vector, indices):
        with pytest.raises(ValueError, match=fragment):
            cirq.subwavefunction(vector, indices, atol=1e-8)

    # Length 7 is not a power of two, so it cannot be a qubit wavefunction.
    expect_value_error('7', np.arange(7), [1, 2])

    # State shape does not conform to input requirements.
    expect_value_error('shaped', np.arange(16).reshape((2, 4, 2)), [1, 2])
    expect_value_error('shaped', np.arange(16).reshape((16, 1)), [1, 2])

    # Unnormalized amplitudes.
    expect_value_error('normalized', np.arange(16), [1, 2])

    # Bad choice of input indices: a duplicate index...
    three_qubit = np.arange(8) / np.linalg.norm(np.arange(8))
    expect_value_error('2, 2', three_qubit, [1, 2, 2])
    # ...an index beyond the 2 available qubits, and too many indices.
    two_qubit = np.array([1, 0, 0, 0]).reshape((2, 2))
    expect_value_error('invalid', two_qubit, [5])
    expect_value_error('invalid', two_qubit, [0, 1, 2])
def test_subwavefunction():
    """Each factor of a 2 x 3 x 4-qubit product state is recovered up to phase.

    Also checks that the output shape follows the input shape (tensor vs.
    flat) and that `atol` controls whether factoring is accepted.
    """

    def normalized(vec):
        return vec / np.linalg.norm(vec)

    a = normalized(np.arange(4))
    b = normalized(np.arange(8) + 3)
    c = normalized(np.arange(16) + 1)
    state = np.kron(np.kron(a, b), c).reshape((2,) * 9)

    # (factor, qubit indices of that factor inside `state`)
    factors = [(a, [0, 1]), (b, [2, 3, 4]), (c, [5, 6, 7, 8])]

    # A factor taken alone is its own sub-state over all of its qubits.
    for vec, indices in factors:
        own_qubits = list(range(len(indices)))
        assert cirq.equal_up_to_global_phase(
            cirq.subwavefunction(vec, own_qubits, atol=1e-8), vec)

    # A tensor-shaped input yields tensor-shaped factors.
    for vec, indices in factors:
        expected = vec.reshape((2,) * len(indices))
        assert cirq.equal_up_to_global_phase(
            cirq.subwavefunction(state, indices, atol=1e-15), expected)

    # Output wavefunction conforms to the shape of the input wavefunction.
    flat_state = state.reshape(-1)
    for vec, indices in factors:
        assert cirq.equal_up_to_global_phase(
            cirq.subwavefunction(flat_state, indices, atol=1e-15), vec)

    # Reject factoring for very tight tolerance.
    for _, indices in factors:
        assert cirq.subwavefunction(state, indices, default=None, atol=1e-16) is None

    # Permit invalid factoring for loose tolerance.
    for qubit in range(9):
        assert cirq.subwavefunction(state, [qubit], default=None, atol=1) is not None