def test_partial_trace_of_state_vector_as_mixture_mixed_result():
    """Tracing out part of an entangled state yields a mixed result.

    Covers a two-qubit Bell state, a three-qubit W state and a three-qubit
    GHZ state, keeping single qubits and (for GHZ) pairs of qubits.
    """
    # Equal-weight mixture of |0> and |1>.
    even_mix = ((0.5, np.array([1, 0])), (0.5, np.array([0, 1])))

    # (|00> + |11>) / sqrt(2): either qubit on its own is an even mixture.
    bell = np.array([1, 0, 0, 1]) / np.sqrt(2)
    for kept in (0, 1):
        result = cirq.partial_trace_of_state_vector_as_mixture(bell, [kept], atol=1e-8)
        assert mixtures_equal(result, even_mix)

    # (|001> + |010> + |100>) / sqrt(3): each qubit is |1> with weight 1/3.
    w_state = np.array([0, 1, 1, 0, 1, 0, 0, 0]).reshape((2, 2, 2)) / np.sqrt(3)
    w_expected = ((1 / 3, np.array([0.0, 1.0])), (2 / 3, np.array([1.0, 0.0])))
    for kept in (0, 1, 2):
        result = cirq.partial_trace_of_state_vector_as_mixture(w_state, [kept], atol=1e-8)
        assert mixtures_equal(result, w_expected)

    # (|000> + |111>) / sqrt(2): single qubits and qubit pairs are both mixed.
    ghz = np.array([1, 0, 0, 0, 0, 0, 0, 1]).reshape((2, 2, 2)) / np.sqrt(2)
    for kept in (0, 1, 2):
        result = cirq.partial_trace_of_state_vector_as_mixture(ghz, [kept], atol=1e-8)
        assert mixtures_equal(result, even_mix)

    pair_expected = (
        (0.5, np.array([1, 0, 0, 0]).reshape((2, 2))),
        (0.5, np.array([0, 0, 0, 1]).reshape((2, 2))),
    )
    for pair in ((0, 1), (0, 2), (1, 2)):
        result = cirq.partial_trace_of_state_vector_as_mixture(ghz, list(pair), atol=1e-8)
        assert mixtures_equal(result, pair_expected)
def test_partial_trace_of_state_vector_as_mixture_mixed_result_qudits():
    """An entangled qutrit pair reduces to an even mixture of |0> and |2>."""
    # (|00> + |22>) / sqrt(2), laid out as a 3x3 array (one axis per qutrit).
    qutrit_pair = np.array([[1, 0, 0], [0, 0, 0], [0, 0, 1]]) / np.sqrt(2)
    expected = ((0.5, np.array([1, 0, 0])), (0.5, np.array([0, 0, 1])))
    for kept in (0, 1):
        result = cirq.partial_trace_of_state_vector_as_mixture(qutrit_pair, [kept], atol=1e-8)
        assert mixtures_equal(result, expected)
def test_deprecated():
    """Old wavefunction-era names and keywords must log a deprecation message."""
    raw = np.arange(4)
    vec = raw / np.linalg.norm(raw)

    with cirq.testing.assert_logs('subwavefunction', 'sub_state_vector', 'deprecated'):
        _ = cirq.subwavefunction(vec, [0, 1], atol=1e-8)

    with cirq.testing.assert_logs('wavefunction', 'state_vector', 'deprecated'):
        # pylint: disable=unexpected-keyword-arg,no-value-for-parameter
        _ = cirq.sub_state_vector(wavefunction=vec, keep_indices=[0, 1], atol=1e-8)

    with cirq.testing.assert_logs(
        'wavefunction_partial_trace_as_mixture',
        'partial_trace_of_state_vector_as_mixture',
        'deprecated',
    ):
        _ = cirq.wavefunction_partial_trace_as_mixture(vec, [0])

    with cirq.testing.assert_logs('wavefunction', 'state_vector', 'deprecated'):
        # pylint: disable=unexpected-keyword-arg,no-value-for-parameter
        _ = cirq.partial_trace_of_state_vector_as_mixture(wavefunction=vec, keep_indices=[0])
def get_state(self, resolver: cirq.ParamResolver, trace_dims: list = None, max_traced: bool = True):
    """Simulate the resolved circuit and return a state together with its magnitudes.

    Args:
        resolver: Parameter values used to resolve ``self.circuit`` before simulation.
        trace_dims: Qubit indices passed to
            ``cirq.partial_trace_of_state_vector_as_mixture`` as its keep indices,
            i.e. the qubits that are KEPT (all others are traced out). No trace is
            taken when this is falsy or has at least ``self.get_circuit_size()``
            entries.
        max_traced: After tracing, whether to return only the highest-probability
            state of the resulting mixture (its probability is printed).

    Returns:
        * No trace taken: ``(state_vector, abs(state_vector))`` for the full state.
        * Traced and ``max_traced`` true: ``(state, abs(state))`` for the most
          probable state in the mixture.
        * Traced and ``max_traced`` false: ``(mixture, magnitudes)``. ``mixture`` is
          the tuple of ``(probability, state)`` pairs returned by cirq, and
          ``magnitudes`` holds the matching ``(probability, abs(state))`` pairs.
    """
    resolved_circuit = cirq.resolve_parameters(self.circuit, resolver)
    state_vector = cirq.final_state_vector(resolved_circuit)

    # NOTE(review): presumably get_circuit_size() is the circuit's qubit count;
    # keeping that many (or more) qubits means there is nothing to trace out.
    if not trace_dims or len(trace_dims) >= self.get_circuit_size():
        return state_vector, abs(state_vector)

    mixture = cirq.partial_trace_of_state_vector_as_mixture(state_vector, trace_dims)
    if max_traced:
        prob, dominant = max(mixture, key=lambda term: term[0])
        print(f"Max probability state after tracing has probability: {prob}")
        return dominant, abs(dominant)

    # The mixture is a tuple of (probability, array) pairs. abs() on the tuple
    # itself raises TypeError, so take the magnitudes term by term.
    return mixture, tuple((prob, abs(state)) for prob, state in mixture)
def test_partial_trace_of_state_vector_as_mixture_invalid_input():
    """Malformed state vectors and keep-index lists are rejected.

    NOTE(review): a second function with exactly this name is defined later in
    this file. At import time that later definition replaces this one, so pytest
    never runs these assertions. The two versions also expect different exception
    types and messages; reconcile them against the installed cirq version.
    """
    # Length 7 is not a valid state-vector size; the error should mention it.
    with pytest.raises(ValueError, match='7'):
        cirq.partial_trace_of_state_vector_as_mixture(np.arange(7), [1, 2], atol=1e-8)

    with pytest.raises(ValueError, match='normalized'):
        cirq.partial_trace_of_state_vector_as_mixture(np.arange(8), [1], atol=1e-8)

    raw = np.arange(8)
    normalized = raw / np.linalg.norm(raw)
    with pytest.raises(ValueError, match='repeated axis'):
        cirq.partial_trace_of_state_vector_as_mixture(normalized, [1, 2, 2], atol=1e-8)

    two_qubit = np.array([1, 0, 0, 0]).reshape((2, 2))
    for bad_keep in ([5], [0, 1, 2]):
        with pytest.raises(IndexError, match='out of range'):
            cirq.partial_trace_of_state_vector_as_mixture(two_qubit, bad_keep, atol=1e-8)
def test_partial_trace_of_state_vector_as_mixture_pure_result_qudits():
    """Keeping whole factors of a qudit product state gives that factor as a pure state."""
    qubit_part = cirq.testing.random_superposition(2)
    qutrit_part = cirq.testing.random_superposition(3)
    ququart_part = cirq.testing.random_superposition(4)
    product = np.kron(np.kron(qubit_part, qutrit_part), ququart_part).reshape((2, 3, 4))

    # (kept axes, expected pure state) pairs.
    cases = [
        ([0], qubit_part),
        ([1], qutrit_part),
        ([2], ququart_part),
        ([0, 1], np.kron(qubit_part, qutrit_part).reshape((2, 3))),
        ([0, 2], np.kron(qubit_part, ququart_part).reshape((2, 4))),
        ([1, 2], np.kron(qutrit_part, ququart_part).reshape((3, 4))),
    ]
    for keep, factor in cases:
        mixture = cirq.partial_trace_of_state_vector_as_mixture(product, keep, atol=1e-8)
        assert mixtures_equal(mixture, ((1.0, factor),))
def test_partial_trace_of_state_vector_as_mixture_pure_result():
    """Keeping whole factors of a qubit product state gives that factor as a pure state."""
    two_q = cirq.testing.random_superposition(4)
    three_q = cirq.testing.random_superposition(8)
    four_q = cirq.testing.random_superposition(16)
    tensor_state = np.kron(np.kron(two_q, three_q), four_q).reshape((2,) * 9)

    # (kept qubits, expected pure state) pairs for the 9-axis tensor input.
    tensor_cases = [
        ([0, 1], two_q.reshape(2, 2)),
        ([2, 3, 4], three_q.reshape(2, 2, 2)),
        ([5, 6, 7, 8], four_q.reshape(2, 2, 2, 2)),
        ([0, 1, 2, 3, 4], np.kron(two_q, three_q).reshape((2,) * 5)),
        ([0, 1, 5, 6, 7, 8], np.kron(two_q, four_q).reshape((2,) * 6)),
        ([2, 3, 4, 5, 6, 7, 8], np.kron(three_q, four_q).reshape((2,) * 7)),
    ]
    for keep, factor in tensor_cases:
        mixture = cirq.partial_trace_of_state_vector_as_mixture(tensor_state, keep, atol=1e-8)
        assert mixtures_equal(mixture, ((1.0, factor),))

    # Output states follow the input's layout: a flat input yields flat factors.
    flat_state = tensor_state.reshape(2**9)
    flat_cases = [([0, 1], two_q), ([2, 3, 4], three_q), ([5, 6, 7, 8], four_q)]
    for keep, factor in flat_cases:
        mixture = cirq.partial_trace_of_state_vector_as_mixture(flat_state, keep, atol=1e-8)
        assert mixtures_equal(mixture, ((1.0, factor),))

    # The returned mixture is limited by numpy.linalg.eigh's own precision. Even
    # with atol=1e-20 it matches the ideal mixture to 1e-15 but not to 1e-16.
    bell = np.array([1, 0, 0, 1]) / np.sqrt(2)
    ideal = ((0.5, np.array([1, 0])), (0.5, np.array([0, 1])))
    assert mixtures_equal(
        cirq.partial_trace_of_state_vector_as_mixture(bell, [1], atol=1e-20), ideal, atol=1e-15
    )
    assert not mixtures_equal(
        cirq.partial_trace_of_state_vector_as_mixture(bell, [1], atol=1e-20), ideal, atol=1e-16
    )
def test_partial_trace_of_state_vector_as_mixture_invalid_input():
    """Malformed state vectors and keep-index lists raise ValueError.

    NOTE(review): an earlier function with exactly this name exists in this
    file. This later definition replaces it at import time, so only these
    assertions are run by pytest.
    """
    # Length 7 is not a valid state-vector size; the error should mention it.
    with pytest.raises(ValueError, match='7'):
        cirq.partial_trace_of_state_vector_as_mixture(np.arange(7), [1, 2], atol=1e-8)

    # Arrays with these non-flat shapes are rejected with a shape complaint.
    for misshapen in (np.arange(16).reshape((2, 4, 2)), np.arange(16).reshape((16, 1))):
        with pytest.raises(ValueError, match='shaped'):
            cirq.partial_trace_of_state_vector_as_mixture(misshapen, [1], atol=1e-8)

    with pytest.raises(ValueError, match='normalized'):
        cirq.partial_trace_of_state_vector_as_mixture(np.arange(8), [1], atol=1e-8)

    raw = np.arange(8)
    normalized = raw / np.linalg.norm(raw)
    # Repeated keep indices: the message should echo the offending list.
    with pytest.raises(ValueError, match='2, 2'):
        cirq.partial_trace_of_state_vector_as_mixture(normalized, [1, 2, 2], atol=1e-8)

    two_qubit = np.array([1, 0, 0, 0]).reshape((2, 2))
    for bad_keep in ([5], [0, 1, 2]):
        with pytest.raises(ValueError, match='invalid'):
            cirq.partial_trace_of_state_vector_as_mixture(two_qubit, bad_keep, atol=1e-8)