def test_merge_moments():
    """Adjacent all-Z moments on either side of a CCX are merged by `_merge_z_moments_func`."""
    qubits = cirq.LineQubit.range(3)
    z_layers = [
        cirq.Z.on_each(qubits[0], qubits[1]),
        cirq.Z.on_each(qubits[1], qubits[2]),
        cirq.Z.on_each(qubits[1], qubits[0]),
    ]
    z_block = cirq.Circuit(*z_layers, strategy=cirq.InsertStrategy.NEW_THEN_INLINE)
    circuit = cirq.Circuit(z_block, cirq.CCX(*qubits), z_block)

    # Sanity-check the starting layout: three separate Z moments on each side of the CCX.
    cirq.testing.assert_has_diagram(
        circuit,
        '''
0: ───Z───────Z───@───Z───────Z───
                  │
1: ───Z───Z───Z───@───Z───Z───Z───
                  │
2: ───────Z───────X───────Z───────
''',
    )

    # After merging, each side of the CCX collapses to a single moment.
    merged = cirq.merge_moments(circuit, _merge_z_moments_func)
    cirq.testing.assert_has_diagram(
        merged,
        '''
0: ───────@───────
          │
1: ───Z───@───Z───
          │
2: ───Z───X───Z───
''',
    )
def test_merge_moments_deep():
    """With deep=True, merging recurses into CircuitOperations, except those tagged "ignore"."""
    qubits = cirq.LineQubit.range(3)
    z_moments = cirq.Circuit(
        [
            cirq.Z.on_each(qubits[0], qubits[1]),
            cirq.Z.on_each(qubits[1], qubits[2]),
            cirq.Z.on_each(qubits[1], qubits[0]),
        ],
        strategy=cirq.InsertStrategy.NEW_THEN_INLINE,
    )
    single_z_moment = cirq.Moment(cirq.Z.on_each(*qubits[1:]))
    unmerged_body = cirq.FrozenCircuit(z_moments, cirq.CCX(*qubits), z_moments)
    merged_body = cirq.FrozenCircuit(single_z_moment, cirq.CCX(*qubits), single_z_moment)

    def layout(body):
        # The "ignore"-tagged sub-circuit always keeps the unmerged body; the
        # remaining slots (plain, tagged "preserve_tag", and untagged) use `body`.
        return cirq.Circuit(
            cirq.CircuitOperation(unmerged_body).repeat(5).with_tags("ignore"),
            body,
            cirq.CircuitOperation(body).repeat(6).with_tags("preserve_tag"),
            body,
            cirq.CircuitOperation(body).repeat(7),
        )

    actual = cirq.merge_moments(
        layout(unmerged_body), _merge_z_moments_func, tags_to_ignore=("ignore",), deep=True
    )
    cirq.testing.assert_same_circuits(actual, layout(merged_body))
def test_merge_moments_empty_moment_as_intermediate_step():
    """A merge step that yields an empty moment must not stop later merges."""
    qubit = cirq.NamedQubit("q")
    original = cirq.Circuit([cirq.X(qubit), cirq.Y(qubit), cirq.Z(qubit)] * 2, cirq.X(qubit) ** 0.5)

    def merge_func(m1: cirq.Moment, m2: cirq.Moment):
        # Fuse the two moments' unitary into one PhasedXZ gate. When the
        # converter returns no gate, emit an empty moment instead.
        combined = cirq.unitary(cirq.Circuit(m1, m2))
        phxz = cirq.single_qubit_matrix_to_phxz(combined, atol=1e-8)
        if not phxz:
            return cirq.Moment([])
        return cirq.Moment(phxz.on(qubit))

    merged = cirq.merge_moments(original, merge_func)
    assert len(merged) == 1
    assert isinstance(merged[0][qubit].gate, cirq.PhasedXZGate)
    cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent(original, merged, atol=1e-8)
def test_merge_moments_with_local_merge_func():
    """Merge adjacent all-Z moments using a merge function defined inside the test.

    This test needs its own name. A second ``def test_merge_moments`` would
    replace the earlier definition in the module namespace, and pytest would
    then only run one of the two tests.
    """
    q = cirq.LineQubit.range(3)
    c_orig = cirq.Circuit(
        cirq.Z.on_each(q[0], q[1]),
        cirq.Z.on_each(q[1], q[2]),
        cirq.Z.on_each(q[1], q[0]),
        strategy=cirq.InsertStrategy.NEW_THEN_INLINE,
    )
    c_orig = cirq.Circuit(c_orig, cirq.CCX(*q), c_orig)
    cirq.testing.assert_has_diagram(
        c_orig,
        '''
0: ───Z───────Z───@───Z───────Z───
                  │
1: ───Z───Z───Z───@───Z───Z───Z───
                  │
2: ───────Z───────X───────Z───────
''',
    )

    def merge_func(m1: cirq.Moment, m2: cirq.Moment) -> Optional[cirq.Moment]:
        """Merge two moments when both contain only Z gates; otherwise return None."""

        def is_z_moment(m):
            return all(op.gate == cirq.Z for op in m)

        if not (is_z_moment(m1) and is_z_moment(m2)):
            return None
        qubits = m1.qubits | m2.qubits

        def mul(op1, op2):
            # If only one moment acts on the qubit, keep that operation unchanged.
            # If both do, multiply the two operations and decompose the product.
            # Z * Z is the identity, which presumably decomposes to no operations:
            # qubit 0 ends up empty in the expected diagram below.
            if not (op1 and op2):
                return op1 or op2
            return cirq.decompose_once(op1 * op2)

        # Use `qubit` rather than `q` so the outer qubit list is not shadowed.
        return cirq.Moment(mul(m1.operation_at(qubit), m2.operation_at(qubit)) for qubit in qubits)

    cirq.testing.assert_has_diagram(
        cirq.merge_moments(c_orig, merge_func),
        '''
0: ───────@───────
          │
1: ───Z───@───Z───
          │
2: ───Z───X───Z───
''',
    )
def test_merge_moments_empty_circuit():
    """An empty circuit is returned as-is, and the merge callback is never invoked."""

    def never_called(*_args):
        assert False

    empty = cirq.Circuit()
    result = cirq.merge_moments(empty, never_called)
    assert result is empty