def test_initial_state_vector():
    """A state-vector initial_state, flat or qubit-shaped, yields a 6-axis tensor.

    Both the flat (8,) vector and the (2, 2, 2) tensor form of the uniform
    superposition over three qubits are accepted, and the resulting
    target tensor has shape (2,) * 6.
    """
    qubits = cirq.LineQubit.range(3)
    amplitude = 1 / np.sqrt(8)
    # Three qubits -> six axes of dimension 2 in the density-matrix tensor.
    expected_shape = (2,) * 6

    flat_state = cirq.DensityMatrixSimulationState(
        qubits=qubits,
        initial_state=np.full((8,), amplitude),
        dtype=np.complex64,
    )
    assert flat_state.target_tensor.shape == expected_shape

    shaped_state = cirq.DensityMatrixSimulationState(
        qubits=qubits,
        initial_state=np.full((2, 2, 2), amplitude),
        dtype=np.complex64,
    )
    assert shaped_state.target_tensor.shape == expected_shape
def test_with_qubits():
    """with_qubits adds new qubits to an existing density-matrix state.

    The original single qubit starts in |1>. The expected target tensor is
    the kronecker product of its |1><1| density matrix with |0><0| for the
    added qubit.
    """
    original = cirq.DensityMatrixSimulationState(
        qubits=cirq.LineQubit.range(1), initial_state=1, dtype=np.complex64
    )
    extended = original.with_qubits(cirq.LineQubit.range(1, 2))
    np.testing.assert_almost_equal(
        extended.target_tensor,
        cirq.density_matrix_kronecker_product(
            np.array([[0, 0], [0, 1]], dtype=np.complex64),  # |1><1|
            np.array([[1, 0], [0, 0]], dtype=np.complex64),  # |0><0|
        ),
    )
def test_initial_state_bad_shape():
    """Initial states whose shape does not fit three qubits raise ValueError."""
    qubits = cirq.LineQubit.range(3)
    # (shape, fill value) pairs, each expected to be rejected for 3 qubits.
    rejected_inputs = [
        ((4,), 1 / 2),
        ((2, 2), 1 / 2),
        ((4, 4), 1 / 4),
        ((2, 2, 2, 2), 1 / 4),
    ]
    for shape, fill in rejected_inputs:
        with pytest.raises(ValueError, match="Invalid quantum state"):
            cirq.DensityMatrixSimulationState(
                qubits=qubits,
                initial_state=np.full(shape, fill),
                dtype=np.complex64,
            )
def test_cannot_act():
    """act_on raises TypeError for a bare object with nothing to simulate."""
    class NoDetails:
        pass

    state = cirq.DensityMatrixSimulationState(
        qubits=cirq.LineQubit.range(1),
        prng=np.random.RandomState(),
        initial_state=0,
        dtype=np.complex64,
    )
    unsupported = NoDetails()
    with pytest.raises(TypeError, match="Can't simulate operations"):
        cirq.act_on(unsupported, state, qubits=())
def test_default_parameter():
    """A state built without an explicit dtype matches a complex64 reference.

    The test expects the target tensor to equal the |0><0| density matrix,
    three available buffers with the reference's shape and dtype, and the
    qid shape (2,).
    """
    qid_shape = (2,)
    reference = cirq.to_valid_density_matrix(
        0, len(qid_shape), qid_shape=qid_shape, dtype=np.complex64
    )
    state = cirq.DensityMatrixSimulationState(
        qubits=cirq.LineQubit.range(1), initial_state=0
    )
    np.testing.assert_almost_equal(state.target_tensor, reference)

    buffers = state.available_buffer
    assert len(buffers) == 3
    assert all(buf.shape == reference.shape for buf in buffers)
    assert all(buf.dtype == reference.dtype for buf in buffers)
    assert state.qid_shape == qid_shape
def test_decomposed_fallback():
    """act_on simulates a gate through its _decompose_ result."""
    class Composite(cirq.Gate):
        def num_qubits(self) -> int:
            return 1

        def _decompose_(self, qubits):
            # The gate is equivalent to a single X on its one qubit.
            yield cirq.X(*qubits)

    line = cirq.LineQubit.range(1)
    state = cirq.DensityMatrixSimulationState(
        qubits=line,
        prng=np.random.RandomState(),
        initial_state=0,
        dtype=np.complex64,
    )

    cirq.act_on(Composite(), state, line)
    # Starting from |0><0|, applying X should leave |1><1|: a lone 1 at (1, 1).
    flipped = cirq.one_hot(index=(1, 1), shape=(2, 2), dtype=np.complex64)
    np.testing.assert_allclose(state.target_tensor, flipped)
def test_shallow_copy_buffers():
    """copy(deep_copy_buffers=False) reuses the very same buffer object."""
    original = cirq.DensityMatrixSimulationState(
        qubits=cirq.LineQubit.range(1), initial_state=0
    )
    shallow = original.copy(deep_copy_buffers=False)
    assert shallow.available_buffer is original.available_buffer