def test_matrix_mixture_repr():
    """repr() spells out each (probability, matrix) pair, dtype and key."""
    identity = np.array([[1, 0], [0, 1]], dtype=np.complex64)
    pauli_x = np.array([[0, 1], [1, 0]], dtype=np.complex64)
    channel = cirq.MixedUnitaryChannel([(0.5, identity), (0.5, pauli_x)],
                                       key='flip')
    # The expected repr is a single line; it is assembled from pieces here
    # only to keep the source readable.
    expected = (
        "cirq.MixedUnitaryChannel(mixture=[" +
        "(0.5, np.array([[(1+0j), 0j], [0j, (1+0j)]], dtype=np.complex64)), " +
        "(0.5, np.array([[0j, (1+0j)], [(1+0j), 0j]], dtype=np.complex64))], " +
        "key='flip')")
    assert repr(channel) == expected
def test_matrix_mixture_str():
    """str() shows the mixture list and appends the key only when one is set."""
    identity = np.array([[1, 0], [0, 1]])
    pauli_x = np.array([[0, 1], [1, 0]])
    mix = [(0.5, identity), (0.5, pauli_x)]

    # The second row of each printed 2x2 array is indented so it lines up
    # under the first row, i.e. by len('array([') == 7 spaces.
    row_break = '\n' + ' ' * len('array([')
    mixture_text = ('[(0.5, array([[1, 0],' + row_break + '[0, 1]])), ' +
                    '(0.5, array([[0, 1],' + row_break + '[1, 0]]))]')

    unkeyed = cirq.MixedUnitaryChannel(mix)
    assert str(unkeyed) == 'MixedUnitaryChannel(' + mixture_text + ')'

    keyed = cirq.MixedUnitaryChannel(mix, key='flip')
    assert str(keyed) == ('MixedUnitaryChannel(' + mixture_text +
                          ', key=flip)')
def test_validate():
    """validate=True rejects a mixture whose matrices are not unitary."""
    # Probabilities sum to 1, but both diagonal projectors are non-unitary.
    projector_0 = np.array([[1, 0], [0, 0]])
    projector_1 = np.array([[0, 0], [0, 1]])
    non_unitary_mix = [(0.5, projector_0), (0.5, projector_1)]
    with pytest.raises(ValueError, match='non-unitary'):
        cirq.MixedUnitaryChannel(mixture=non_unitary_mix,
                                 key='m',
                                 validate=True)
def test_mix_mismatch_fails():
    """Mixing matrices of different shapes raises ValueError."""
    small = np.array([[1, 0], [0, 0]])
    # Float 4x4 matrix whose only nonzero entry is 1 at [1, 1].
    large = np.diag([0.0, 1.0, 0.0, 0.0])
    mismatched = [(0.5, small), (0.5, large)]

    with pytest.raises(ValueError, match='Inconsistent unitary shapes'):
        cirq.MixedUnitaryChannel(mixture=mismatched, key='m')
def test_nonqubit_mixture_fails():
    """Non-square (2x3) matrices are rejected by the constructor."""
    upper = np.array([[1, 0, 0], [0, 1, 0]])
    swapped = np.array([[0, 1, 0], [1, 0, 0]])
    rectangular_mix = [(0.5, upper), (0.5, swapped)]

    with pytest.raises(ValueError, match='Input mixture'):
        cirq.MixedUnitaryChannel(mixture=rectangular_mix, key='m')
def test_measured_mixture():
    """The recorded 'flip' outcome matches a subsequent Z measurement."""
    # Equal mixture of I and X whose chosen index is stored under 'flip'.
    # From the simulator's |0> start, I leaves the qubit at 0 and X flips it
    # to 1, so the 'flip' and 'm' histograms must agree.
    identity = np.array([[1, 0], [0, 1]])
    pauli_x = np.array([[0, 1], [1, 0]])
    recorded_flip = cirq.MixedUnitaryChannel(
        mixture=((0.5, identity), (0.5, pauli_x)), key='flip')

    qubit = cirq.LineQubit(0)
    circuit = cirq.Circuit(recorded_flip.on(qubit),
                           cirq.measure(qubit, key='m'))
    results = cirq.Simulator(seed=0).run(circuit, repetitions=100)
    assert results.histogram(key='flip') == results.histogram(key='m')
def test_matrix_mixture_from_unitaries():
    """A keyed channel exposes its key and records it during simulation."""
    identity = np.array([[1, 0], [0, 1]])
    pauli_x = np.array([[0, 1], [1, 0]])
    channel = cirq.MixedUnitaryChannel([(0.5, identity), (0.5, pauli_x)],
                                       key='flip')
    assert cirq.measurement_key_name(channel) == 'flip'

    qubit = cirq.LineQubit(0)
    circuit = cirq.Circuit(channel.on(qubit), cirq.measure(qubit, key='m'))
    measurements = cirq.Simulator(seed=0).simulate(circuit).measurements

    assert 'flip' in measurements
    # The applied unitary (I or X) determines the measured bit.
    assert measurements['flip'] == measurements['m']
def test_matrix_mixture_equality():
    """Equality depends on parameters, key and entry order, not on effect."""
    depol_low = cirq.depolarize(0.1)
    depol_high = cirq.depolarize(0.2)
    low_a = cirq.MixedUnitaryChannel.from_mixture(depol_low, key='a')
    high_a = cirq.MixedUnitaryChannel.from_mixture(depol_high, key='a')
    low_b = cirq.MixedUnitaryChannel.from_mixture(depol_low, key='b')

    # A MixedUnitaryChannel is never equal to the built-in channel it was
    # built from, and differing probabilities or keys also break equality.
    unequal_pairs = [
        (low_a, depol_low),
        (low_a, high_a),
        (low_a, low_b),
        (high_a, low_b),
    ]
    for lhs, rhs in unequal_pairs:
        assert lhs != rhs

    identity = np.array([[1, 0], [0, 1]])
    pauli_x = np.array([[0, 1], [1, 0]])
    forward = cirq.MixedUnitaryChannel([(0.5, identity), (0.5, pauli_x)])
    backward = cirq.MixedUnitaryChannel([(0.5, pauli_x), (0.5, identity)])
    # Same action on the state, but the recorded measurement index differs
    # with entry order, so the two are not equal.
    assert forward != backward
def test_mix_bad_prob_fails():
    """Probabilities that do not sum to 1 raise ValueError."""
    # A lone entry with probability 0.5 leaves the total short of 1.
    underweight = [(0.5, np.array([[1, 0], [0, 0]]))]
    expected_error = 'Unitary probabilities must sum to 1'

    with pytest.raises(ValueError, match=expected_error):
        cirq.MixedUnitaryChannel(mixture=underweight, key='m')
def test_mix_no_unitaries_fails():
    """An empty mixture is rejected with ValueError."""
    empty_mixture = []
    with pytest.raises(ValueError, match='must have at least one unitary'):
        cirq.MixedUnitaryChannel(mixture=empty_mixture, key='m')