Ejemplo n.º 1
0
def test_merge_1q_unitaries():
    q, q2 = cirq.LineQubit.range(2)
    # 1. Combines trivial 1q sequence.
    c = cirq.Circuit(cirq.X(q)**0.5, cirq.Z(q)**0.5, cirq.X(q)**-0.5)
    c = cirq.merge_k_qubit_unitaries(c, k=1)
    op_list = [*c.all_operations()]
    assert len(op_list) == 1
    assert isinstance(op_list[0].gate, cirq.MatrixGate)
    cirq.testing.assert_allclose_up_to_global_phase(cirq.unitary(c),
                                                    cirq.unitary(cirq.Y**0.5),
                                                    atol=1e-7)

    # 2. Gets blocked at a 2q operation.
    c = cirq.Circuit([
        cirq.Z(q),
        cirq.H(q),
        cirq.X(q),
        cirq.H(q),
        cirq.CZ(q, q2),
        cirq.H(q)
    ])
    c = cirq.drop_empty_moments(cirq.merge_k_qubit_unitaries(c, k=1))
    assert len(c) == 3
    cirq.testing.assert_allclose_up_to_global_phase(cirq.unitary(c[0]),
                                                    np.eye(2),
                                                    atol=1e-7)
    assert isinstance(c[-1][q].gate, cirq.MatrixGate)
Ejemplo n.º 2
0
def test_ignore_unsupported_gate():
    class UnsupportedDummy(cirq.testing.SingleQubitGate):
        pass

    c = cirq.Circuit(UnsupportedDummy()(cirq.LineQubit(0)))
    assert_optimizes(optimized=cirq.merge_k_qubit_unitaries(c, k=1),
                     expected=c)
Ejemplo n.º 3
0
def test_ignore_unsupported_gate():
    class UnsupportedDummy(cirq.Gate):
        def _num_qubits_(self) -> int:
            return 1

    c = cirq.Circuit(UnsupportedDummy()(cirq.LineQubit(0)))
    assert_optimizes(optimized=cirq.merge_k_qubit_unitaries(c, k=1), expected=c)
Ejemplo n.º 4
0
def test_1q_rewrite():
    q0, q1 = cirq.LineQubit.range(2)
    circuit = cirq.Circuit(cirq.X(q0), cirq.Y(q0), cirq.X(q1), cirq.CZ(q0, q1),
                           cirq.Y(q1), cirq.measure(q0, q1))
    assert_optimizes(
        optimized=cirq.merge_k_qubit_unitaries(
            circuit, k=1, rewriter=lambda ops: cirq.H(ops.qubits[0])),
        expected=cirq.Circuit(cirq.H(q0), cirq.H(q1), cirq.CZ(q0, q1),
                              cirq.H(q1), cirq.measure(q0, q1)),
    )
Ejemplo n.º 5
0
def test_respects_nocompile_tags():
    q = cirq.NamedQubit("q")
    c = cirq.Circuit(
        [cirq.Z(q), cirq.H(q), cirq.X(q), cirq.H(q), cirq.X(q).with_tags("nocompile"), cirq.H(q)]
    )
    context = cirq.TransformerContext(tags_to_ignore=("nocompile",))
    c = cirq.drop_empty_moments(cirq.merge_k_qubit_unitaries(c, k=1, context=context))
    assert len(c) == 3
    cirq.testing.assert_allclose_up_to_global_phase(cirq.unitary(c[0]), np.eye(2), atol=1e-7)
    assert c[1][q] == cirq.X(q).with_tags("nocompile")
    assert isinstance(c[-1][q].gate, cirq.MatrixGate)
Ejemplo n.º 6
0
def test_merge_k_qubit_unitaries_deep_recurses_on_large_circuit_op():
    q = cirq.LineQubit.range(2)
    c_orig = cirq.Circuit(
        cirq.CircuitOperation(
            cirq.FrozenCircuit(cirq.X(q[0]), cirq.H(q[0]), cirq.CNOT(*q))))
    c_expected = cirq.Circuit(
        cirq.CircuitOperation(
            cirq.FrozenCircuit(
                cirq.CircuitOperation(
                    cirq.FrozenCircuit(cirq.X(q[0]),
                                       cirq.H(q[0]))).with_tags("merged"),
                cirq.CNOT(*q),
            )))
    c_new = cirq.merge_k_qubit_unitaries(
        c_orig,
        context=cirq.TransformerContext(deep=True),
        k=1,
        rewriter=lambda op: op.with_tags("merged"),
    )
    cirq.testing.assert_same_circuits(c_new, c_expected)
Ejemplo n.º 7
0
 def _decompose_two_qubit_operation(self, op: cirq.Operation,
                                    _) -> cirq.OP_TREE:
     if not cirq.has_unitary(op):
         return NotImplemented
     mat = cirq.unitary(op)
     q0, q1 = op.qubits
     naive = cirq.two_qubit_matrix_to_cz_operations(q0,
                                                    q1,
                                                    mat,
                                                    allow_partial_czs=False)
     temp = cirq.map_operations_and_unroll(
         cirq.Circuit(naive),
         lambda op, _: [
             cirq.H(op.qubits[1]),
             cirq.CNOT(*op.qubits),
             cirq.H(op.qubits[1])
         ] if op.gate == cirq.CZ else op,
     )
     return cirq.merge_k_qubit_unitaries(
         temp,
         k=1,
         rewriter=lambda op: self._decompose_single_qubit_operation(
             op, -1)).all_operations()
Ejemplo n.º 8
0
def test_ignores_2qubit_target():
    c = cirq.Circuit(cirq.CZ(*cirq.LineQubit.range(2)))
    assert_optimizes(optimized=cirq.merge_k_qubit_unitaries(c, k=1),
                     expected=c)
Ejemplo n.º 9
0
def test_merge_k_qubit_unitaries_deep():
    q = cirq.LineQubit.range(2)
    h_cz_y = [cirq.H(q[0]), cirq.CZ(*q), cirq.Y(q[1])]
    c_orig = cirq.Circuit(
        h_cz_y,
        cirq.Moment(cirq.X(q[0]).with_tags("ignore"), cirq.Y(q[1])),
        cirq.CircuitOperation(
            cirq.FrozenCircuit(h_cz_y)).repeat(6).with_tags("ignore"),
        [cirq.CNOT(*q), cirq.CNOT(*q)],
        cirq.CircuitOperation(cirq.FrozenCircuit(h_cz_y)).repeat(4),
        [cirq.CNOT(*q), cirq.CZ(*q), cirq.CNOT(*q)],
        cirq.CircuitOperation(
            cirq.FrozenCircuit(h_cz_y)).repeat(5).with_tags("preserve_tag"),
    )

    def _wrap_in_cop(ops: cirq.OP_TREE, tag: str):
        return cirq.CircuitOperation(cirq.FrozenCircuit(ops)).with_tags(tag)

    c_expected = cirq.Circuit(
        _wrap_in_cop([h_cz_y, cirq.Y(q[1])], '1'),
        cirq.Moment(cirq.X(q[0]).with_tags("ignore")),
        cirq.CircuitOperation(
            cirq.FrozenCircuit(h_cz_y)).repeat(6).with_tags("ignore"),
        _wrap_in_cop([cirq.CNOT(*q), cirq.CNOT(*q)], '2'),
        cirq.CircuitOperation(cirq.FrozenCircuit(_wrap_in_cop(h_cz_y,
                                                              '3'))).repeat(4),
        _wrap_in_cop([cirq.CNOT(*q), cirq.CZ(*q),
                      cirq.CNOT(*q)], '4'),
        cirq.CircuitOperation(cirq.FrozenCircuit(_wrap_in_cop(
            h_cz_y, '5'))).repeat(5).with_tags("preserve_tag"),
        strategy=cirq.InsertStrategy.NEW,
    )

    component_id = 0

    def rewriter_merge_to_circuit_op(
            op: 'cirq.CircuitOperation') -> 'cirq.OP_TREE':
        nonlocal component_id
        component_id = component_id + 1
        return op.with_tags(f'{component_id}')

    context = cirq.TransformerContext(tags_to_ignore=("ignore", ), deep=True)
    c_new = cirq.merge_k_qubit_unitaries(c_orig,
                                         k=2,
                                         context=context,
                                         rewriter=rewriter_merge_to_circuit_op)
    cirq.testing.assert_same_circuits(c_new, c_expected)

    def _wrap_in_matrix_gate(ops: cirq.OP_TREE):
        op = _wrap_in_cop(ops, 'temp')
        return cirq.MatrixGate(cirq.unitary(op)).on(*op.qubits)

    c_expected_matrix = cirq.Circuit(
        _wrap_in_matrix_gate([h_cz_y, cirq.Y(q[1])]),
        cirq.Moment(cirq.X(q[0]).with_tags("ignore")),
        cirq.CircuitOperation(
            cirq.FrozenCircuit(h_cz_y)).repeat(6).with_tags("ignore"),
        _wrap_in_matrix_gate([cirq.CNOT(*q), cirq.CNOT(*q)]),
        cirq.CircuitOperation(cirq.FrozenCircuit(
            _wrap_in_matrix_gate(h_cz_y))).repeat(4),
        _wrap_in_matrix_gate([cirq.CNOT(*q),
                              cirq.CZ(*q),
                              cirq.CNOT(*q)]),
        cirq.CircuitOperation(cirq.FrozenCircuit(
            _wrap_in_matrix_gate(h_cz_y))).repeat(5).with_tags("preserve_tag"),
        strategy=cirq.InsertStrategy.NEW,
    )
    c_new_matrix = cirq.merge_k_qubit_unitaries(c_orig, k=2, context=context)
    cirq.testing.assert_same_circuits(c_new_matrix, c_expected_matrix)
Ejemplo n.º 10
0
def test_merge_complex_circuit_preserving_moment_structure():
    q = cirq.LineQubit.range(3)
    c_orig = cirq.Circuit(
        cirq.Moment(cirq.H.on_each(*q)),
        cirq.CNOT(q[0], q[2]),
        cirq.CNOT(*q[0:2]),
        cirq.H(q[0]),
        cirq.CZ(*q[:2]),
        cirq.X(q[0]),
        cirq.Y(q[1]),
        cirq.CNOT(*q[0:2]),
        cirq.CNOT(*q[1:3]).with_tags("ignore"),
        cirq.X(q[0]),
        cirq.Moment(
            cirq.X(q[0]).with_tags("ignore"), cirq.Y(q[1]), cirq.Z(q[2])),
        cirq.Moment(cirq.CNOT(*q[:2]), cirq.measure(q[2], key="a")),
        cirq.X(q[0]).with_classical_controls("a"),
        strategy=cirq.InsertStrategy.NEW,
    )
    cirq.testing.assert_has_diagram(
        c_orig,
        '''
0: ───H───@───@───H───@───X───────@─────────────────X───X['ignore']───@───X───
          │   │       │           │                                   │   ║
1: ───H───┼───X───────@───────Y───X───@['ignore']───────Y─────────────X───╫───
          │                           │                                   ║
2: ───H───X───────────────────────────X─────────────────Z─────────────M───╫───
                                                                      ║   ║
a: ═══════════════════════════════════════════════════════════════════@═══^═══
''',
    )
    component_id = 0

    def rewriter_merge_to_circuit_op(
            op: 'cirq.CircuitOperation') -> 'cirq.OP_TREE':
        nonlocal component_id
        component_id = component_id + 1
        return op.with_tags(f'{component_id}')

    c_new = cirq.merge_k_qubit_unitaries(
        c_orig,
        k=2,
        context=cirq.TransformerContext(tags_to_ignore=("ignore", )),
        rewriter=rewriter_merge_to_circuit_op,
    )
    cirq.testing.assert_has_diagram(
        cirq.drop_empty_moments(c_new),
        '''
      [ 0: ───H───@─── ]        [ 0: ───────@───H───@───X───@───X─── ]                                            [ 0: ───────@─── ]
0: ───[           │    ]────────[           │       │       │        ]──────────────────────X['ignore']───────────[           │    ]────────X───
      [ 2: ───H───X─── ]['1']   [ 1: ───H───X───────@───Y───X─────── ]['2']                                       [ 1: ───Y───X─── ]['4']   ║
      │                         │                                                                                 │                         ║
1: ───┼─────────────────────────#2────────────────────────────────────────────@['ignore']─────────────────────────#2────────────────────────╫───
      │                                                                       │                                                             ║
2: ───#2──────────────────────────────────────────────────────────────────────X─────────────[ 2: ───Z─── ]['3']───M─────────────────────────╫───
                                                                                                                  ║                         ║
a: ═══════════════════════════════════════════════════════════════════════════════════════════════════════════════@═════════════════════════^═══''',
    )

    component_id = 0

    def rewriter_replace_with_decomp(
            op: 'cirq.CircuitOperation') -> 'cirq.OP_TREE':
        nonlocal component_id
        component_id = component_id + 1
        tag = f'{component_id}'
        if len(op.qubits) == 1:
            return [cirq.T(op.qubits[0]).with_tags(tag)]
        one_layer = [op.with_tags(tag) for op in cirq.T.on_each(*op.qubits)]
        two_layer = [cirq.SQRT_ISWAP(*op.qubits).with_tags(tag)]
        return [one_layer, two_layer, one_layer]

    c_new = cirq.merge_k_qubit_unitaries(
        c_orig,
        k=2,
        context=cirq.TransformerContext(tags_to_ignore=("ignore", )),
        rewriter=rewriter_replace_with_decomp,
    )
    cirq.testing.assert_has_diagram(
        cirq.drop_empty_moments(c_new),
        '''
0: ───T['1']───iSwap['1']───T['1']───T['2']───iSwap['2']───T['2']─────────────────X['ignore']───T['4']───iSwap['4']───T['4']───X───
               │                              │                                                          │                     ║
1: ────────────┼─────────────────────T['2']───iSwap^0.5────T['2']───@['ignore']─────────────────T['4']───iSwap^0.5────T['4']───╫───
               │                                                    │                                                          ║
2: ───T['1']───iSwap^0.5────T['1']──────────────────────────────────X─────────────T['3']────────M──────────────────────────────╫───
                                                                                                ║                              ║
a: ═════════════════════════════════════════════════════════════════════════════════════════════@══════════════════════════════^═══''',
    )
Ejemplo n.º 11
0
def test_merge_k_qubit_unitaries_raises():
    with pytest.raises(ValueError,
                       match="k should be greater than or equal to 1"):
        _ = cirq.merge_k_qubit_unitaries(cirq.Circuit())