Example #1
0
def expand_composite(
        circuit: 'cirq.AbstractCircuit',
        *,
        context: Optional['cirq.TransformerContext'] = None,
        no_decomp: Callable[[ops.Operation], bool] = (lambda _: False),
) -> 'cirq.Circuit':
    """A transformer that expands composite operations via `cirq.decompose`.

    For each operation in the circuit, this pass examines if the operation can
    be decomposed. If it can be, the operation is replaced with its
    decomposition.

    Transformation is applied using `cirq.map_operations_and_unroll`, which preserves the
    moment structure of the input circuit.

    Args:
        circuit: Input circuit to transform.
        context: `cirq.TransformerContext` storing common configurable options for transformers.
        no_decomp: A predicate forwarded to `cirq.decompose` as its `keep`
            argument, i.e. operations for which it returns True are kept
            rather than decomposed. The default returns False for every
            operation, so everything that can be decomposed is decomposed.

    Returns:
        Copy of the transformed input circuit.
    """
    def map_func(op: 'cirq.Operation', _) -> 'cirq.OP_TREE':
        # In deep mode a `cirq.CircuitOperation` is handed back unchanged
        # instead of being decomposed; `deep` is forwarded to the primitive
        # below.
        if context and context.deep and isinstance(op.untagged,
                                                   circuits.CircuitOperation):
            return op
        # `on_stuck_raise=None`: operations that cannot be decomposed any
        # further are left in place rather than raising.
        return protocols.decompose(op, keep=no_decomp, on_stuck_raise=None)

    return transformer_primitives.map_operations_and_unroll(
        circuit,
        map_func,
        tags_to_ignore=context.tags_to_ignore if context else (),
        deep=context.deep if context else False,
    ).unfreeze(copy=False)
Example #2
0
def _decompose_operations_to_target_gateset(
        circuit: 'cirq.AbstractCircuit',
        *,
        context: Optional['cirq.TransformerContext'] = None,
        gateset: Optional['cirq.Gateset'] = None,
        decomposer: Callable[['cirq.Operation', int],
                             dp.DecomposeResult] = lambda *_: NotImplemented,
        ignore_failures: bool = True,
        tags_to_decompose: Sequence[Hashable] = (),
) -> 'cirq.Circuit':
    """Decomposes every operation to `gateset` using `cirq.decompose` and `decomposer`.

    Each operation `op` of the given circuit is run through the `cirq.decompose`
    protocol, with `decomposer` installed as the intercepting decomposer. As a
    result `op` is recursively decomposed via implicitly defined known
    decompositions (eg: the `_decompose_` magic method on the gate class) until
    either `decomposer` knows how to decompose the operation at hand or that
    operation belongs to `gateset`.

    Args:
        circuit: Input circuit to transform. It will not be modified.
        context: `cirq.TransformerContext` storing common configurable options for transformers.
        gateset: Target gateset, which the decomposed operations should belong to.
        decomposer: A callable type which accepts an (operation, moment_index) and returns
            - An equivalent `cirq.OP_TREE` implementing `op` using gates from `gateset`.
            - `None` or `NotImplemented` if does not know how to decompose a given `op`.
        ignore_failures: If set, operations that fail to convert are left unchanged. If not set,
            conversion failures raise a ValueError.
        tags_to_decompose: `cirq.CircuitOperation`s tagged with any of `tags_to_decompose` will
            be decomposed even if context.deep is True.

    Returns:
        An equivalent circuit containing gates accepted by `gateset`.

    Raises:
        ValueError: If any input operation fails to convert and `ignore_failures` is False.
    """
    deep = context.deep if context else False
    tags_to_ignore = context.tags_to_ignore if context else ()

    def map_func(op: 'cirq.Operation', moment_index: int):
        # In deep mode, sub-circuit operations are passed through untouched
        # unless they carry one of the `tags_to_decompose`.
        keep_subcircuit_intact = (
            deep
            and isinstance(op.untagged, circuits.CircuitOperation)
            and set(op.tags).isdisjoint(tags_to_decompose)
        )
        if keep_subcircuit_intact:
            return op

        def intercept(candidate):
            # Bind the moment index so `decomposer` sees (operation, moment_index).
            return decomposer(candidate, moment_index)

        keep = gateset.validate if gateset else None
        if ignore_failures or gateset is None:
            on_stuck_raise = None
        else:
            on_stuck_raise = _create_on_stuck_raise_error(gateset)

        return dp.decompose(
            op,
            intercepting_decomposer=intercept,
            keep=keep,
            on_stuck_raise=on_stuck_raise,
        )

    mapped = transformer_primitives.map_operations_and_unroll(
        circuit,
        map_func,
        tags_to_ignore=tags_to_ignore,
        deep=deep,
    )
    return mapped.unfreeze(copy=False)
Example #3
0
def merge_k_qubit_unitaries(
    circuit: 'cirq.AbstractCircuit',
    *,
    context: Optional['cirq.TransformerContext'] = None,
    k: int = 0,
    rewriter: Optional[Callable[['cirq.CircuitOperation'],
                                'cirq.OP_TREE']] = None,
) -> 'cirq.Circuit':
    """Merges connected components of unitary operations, acting on <= k qubits.

    A connected component of unitary operations acting on <= k qubits is handed
    to `rewriter`, which converts it into a more desirable form. When no
    `rewriter` is given, each such component is replaced by a single
    `cirq.MatrixGate` holding the unitary matrix of the merged component.

    Args:
        circuit: Input circuit to transform. It will not be modified.
        context: `cirq.TransformerContext` storing common configurable options for transformers.
        k: Connected components of unitary operations acting on <= k qubits are merged.
        rewriter: Callable type that takes a `cirq.CircuitOperation`, encapsulating a connected
            component of unitary operations acting on <= k qubits, and produces a `cirq.OP_TREE`.
            Specifies how to merge the connected component into a more desirable form.

    Returns:
        Copy of the transformed input circuit.

    Raises:
        ValueError: If k <= 0
    """
    if k <= 0:
        raise ValueError(f"k should be greater than or equal to 1. Found {k}.")

    merged_circuit_op_tag = "_merged_k_qubit_unitaries_component"
    tags_to_ignore = context.tags_to_ignore if context else ()

    def map_func(op: 'cirq.Operation', _) -> 'cirq.OP_TREE':
        # Leave anything that is too wide or not unitary exactly as it is.
        if protocols.num_qubits(op) > k or not protocols.has_unitary(op):
            return op

        if not rewriter:
            return ops.MatrixGate(protocols.unitary(op)).on(*op.qubits)

        # The rewriter always receives a CircuitOperation: either the merged
        # component itself, or the lone operation wrapped into one.
        if merged_circuit_op_tag in op.tags:
            component = cast(circuits.CircuitOperation, op.untagged)
        else:
            component = circuits.CircuitOperation(circuits.FrozenCircuit(op))
        return rewriter(component)

    merged = transformer_primitives.merge_k_qubit_unitaries_to_circuit_op(
        circuit,
        k=k,
        tags_to_ignore=tags_to_ignore,
        merged_circuit_op_tag=merged_circuit_op_tag,
    )
    rewritten = transformer_primitives.map_operations_and_unroll(
        merged,
        map_func,
        tags_to_ignore=context.tags_to_ignore if context else (),
    )
    return rewritten.unfreeze(copy=False)
Example #4
0
def _rewrite_merged_k_qubit_unitaries(
    circuit: 'cirq.AbstractCircuit',
    *,
    context: Optional['cirq.TransformerContext'] = None,
    k: int = 0,
    rewriter: Optional[Callable[['cirq.CircuitOperation'], 'cirq.OP_TREE']] = None,
    merged_circuit_op_tag: str = "_merged_k_qubit_unitaries_component",
) -> 'cirq.Circuit':
    """Rewrites unitary operations on <= k qubits, recursing into sub-circuits in deep mode.

    Every operation that acts on at most `k` qubits and has a unitary is either
    passed to `rewriter` (wrapped in a `cirq.CircuitOperation` unless it is
    already one tagged with `merged_circuit_op_tag`) or, when no `rewriter` is
    given, replaced by a `cirq.MatrixGate` built from its unitary. All other
    operations are returned unchanged.

    Args:
        circuit: Circuit whose operations are mapped.
        context: `cirq.TransformerContext`; its `deep` and `tags_to_ignore`
            fields are read here.
        k: Upper bound on the number of qubits of operations to rewrite.
        rewriter: Optional callable turning a `cirq.CircuitOperation` into a
            `cirq.OP_TREE`.
        merged_circuit_op_tag: Tag identifying circuit operations that are
            already-merged components.

    Returns:
        The mapped circuit, unfrozen without copying.
    """
    deep = context.deep if context else False
    tags_to_ignore = context.tags_to_ignore if context else ()

    def map_func(op: 'cirq.Operation', _) -> 'cirq.OP_TREE':
        op_untagged = op.untagged

        # Deep mode: a circuit operation that is not one of the merged
        # components gets its inner circuit rewritten recursively, and the
        # original tags are re-attached afterwards.
        if deep and isinstance(op_untagged, circuits.CircuitOperation):
            if merged_circuit_op_tag not in op.tags:
                inner = _rewrite_merged_k_qubit_unitaries(
                    op_untagged.circuit,
                    context=context,
                    k=k,
                    rewriter=rewriter,
                    merged_circuit_op_tag=merged_circuit_op_tag,
                )
                return op_untagged.replace(circuit=inner.freeze()).with_tags(*op.tags)

        if protocols.num_qubits(op) > k or not protocols.has_unitary(op):
            return op

        if not rewriter:
            return ops.MatrixGate(protocols.unitary(op)).on(*op.qubits)

        if merged_circuit_op_tag in op.tags:
            component = cast(circuits.CircuitOperation, op_untagged)
        else:
            component = circuits.CircuitOperation(circuits.FrozenCircuit(op))
        return rewriter(component)

    mapped = transformer_primitives.map_operations_and_unroll(
        circuit, map_func, tags_to_ignore=tags_to_ignore
    )
    return mapped.unfreeze(copy=False)
Example #5
0
def eject_phased_paulis(
    circuit: 'cirq.AbstractCircuit',
    *,
    context: Optional['cirq.TransformerContext'] = None,
    atol: float = 1e-8,
    eject_parameterized: bool = False,
) -> 'cirq.Circuit':
    """Transformer pass to push X, Y, PhasedX & (certain) PhasedXZ gates to the end of the circuit.

    While being pushed, the gates may absorb Z gates, cancel against other
    X, Y, or PhasedX gates with exponent=1, get merged into measurements (as
    output bit flips), and cause phase kickback operations across CZs (which
    the `cirq.eject_z` transformation can subsequently remove).

    `cirq.PhasedXZGate` with `z_exponent=0` (i.e. equivalent to PhasedXPow) or with `x_exponent=0`
    and `axis_phase_exponent=0` (i.e. equivalent to ZPowGate) are also supported.
    To eject `PhasedXZGates` with arbitrary x/z/axis exponents, run
    `cirq.eject_z(cirq.eject_phased_paulis(cirq.eject_z(circuit)))`.

    Args:
        circuit: Input circuit to transform.
        context: `cirq.TransformerContext` storing common configurable options for transformers.
        atol: Maximum absolute error tolerance. The optimization is permitted to simply drop
            negligible combinations of gates with a threshold determined by this tolerance.
        eject_parameterized: If True, the optimization will attempt to eject parameterized gates
            as well.  This may result in other gates parameterized by symbolic expressions.

    Returns:
        Copy of the transformed input circuit.
    """
    no_symbolic = not eject_parameterized
    tags_to_ignore = set(context.tags_to_ignore) if context else set()
    # Shared mutable state: read and updated by the helpers called from
    # `map_func` as the circuit is traversed.
    held_w_phases: Dict[ops.Qid, value.TParamVal] = {}

    def map_func(op: 'cirq.Operation', _: int) -> 'cirq.OP_TREE':
        # An operation carrying an ignored tag is not touched; whatever is
        # held on its qubits is dumped in front of it.
        if set(op.tags) & tags_to_ignore:
            return [_dump_held(op.qubits, held_w_phases), op]

        # Collect, phase, and merge Ws.
        w = _try_get_known_phased_pauli(op, no_symbolic=no_symbolic)
        if w is not None:
            if single_qubit_decompositions.is_negligible_turn(
                    (w[0] - 1) / 2, atol):
                return _potential_cross_whole_w(op, atol, held_w_phases)
            return _potential_cross_partial_w(op, held_w_phases)

        # Nothing held on any of this operation's qubits: pass it through.
        affected = [q for q in op.qubits if q in held_w_phases]
        if not affected:
            return op

        # Absorb Z rotations.
        z_half_turns = _try_get_known_z_half_turns(op, no_symbolic=no_symbolic)
        if z_half_turns is not None:
            return _absorb_z_into_w(op, held_w_phases)

        # Dump coherent flips into measurement bit flips.
        if isinstance(op.gate, ops.MeasurementGate):
            return _dump_into_measurement(op, held_w_phases)

        # Cross CZs using kickback.
        cz_half_turns = _try_get_known_cz_half_turns(op,
                                                     no_symbolic=no_symbolic)
        if cz_half_turns is not None:
            if len(affected) == 1:
                return _single_cross_over_cz(op, affected[0])
            return _double_cross_over_cz(op, held_w_phases)

        # Don't know how to handle this situation. Dump the gates.
        return [_dump_held(op.qubits, held_w_phases), op]

    # Map operations, then put anything that's still held at the end of the circuit.
    mapped = transformer_primitives.map_operations_and_unroll(circuit, map_func)
    leftovers = _dump_held(held_w_phases.keys(), held_w_phases)
    return circuits.Circuit(mapped, leftovers)
def defer_measurements(
        circuit: 'cirq.AbstractCircuit',
        *,
        context: Optional['cirq.TransformerContext'] = None) -> 'cirq.Circuit':
    """Implements the Deferred Measurement Principle.

    Moves all measurements to the end of the circuit by applying the Deferred
    Measurement Principle. Every non-terminal measurement becomes a set of
    conditional quantum gates onto ancilla qubits, and classically controlled
    operations become quantum-controlled operations driven by those ancilla
    qubits. Measurements of all ancilla qubits are appended to the end of the
    circuit as the last step.

    Optimizing deferred measurements is an area of active research, and future
    iterations may contain optimizations that reduce the number of ancilla
    qubits, so one should not depend on the exact shape of the output from this
    function. Only the logical equivalence is guaranteed to remain unchanged.
    Moment and subcircuit structure is not preserved.

    Args:
        circuit: The circuit to transform. It will not be modified.
        context: `cirq.TransformerContext` storing common configurable options
            for transformers.

    Returns:
        A circuit with equivalent logic, but all measurements at the end of the
        circuit.

    Raises:
        ValueError: If sympy-based classical conditions are used, or if
            conditions based on multi-qubit measurements exist. (The latter of
            these is planned to be implemented soon).
        NotImplementedError: When attempting to defer a measurement with a
            confusion map. (https://github.com/quantumlib/Cirq/issues/5482)
    """
    circuit = transformer_primitives.unroll_circuit_op(
        circuit, deep=True, tags_to_check=None)

    # Terminal measurements are already at the end; they are left alone.
    terminal_measurements = set()
    for _, terminal_op in find_terminal_measurements(circuit):
        terminal_measurements.add(terminal_op)

    # Ancilla qubits standing in for each deferred measurement key.
    measurement_qubits: Dict['cirq.MeasurementKey',
                             List['_MeasurementQid']] = {}

    def defer(op: 'cirq.Operation', _) -> 'cirq.OP_TREE':
        if op in terminal_measurements:
            return op

        gate = op.gate
        if isinstance(gate, ops.MeasurementGate):
            if gate.confusion_map:
                raise NotImplementedError(
                    "Deferring confused measurement is not implemented, but found "
                    f"measurement with key={gate.key} and non-empty confusion map."
                )
            key = value.MeasurementKey.parse_serialized(gate.key)
            ancillas = [_MeasurementQid(key, q) for q in op.qubits]
            measurement_qubits[key] = ancillas
            # Copy each measured qubit onto its ancilla, then flip the ancillas
            # selected by the measurement's invert mask.
            copies = [
                ops.CX(source, ancilla)
                for source, ancilla in zip(op.qubits, ancillas)
            ]
            flips = [
                ops.X(ancillas[index])
                for index, inverted in enumerate(gate.full_invert_mask())
                if inverted
            ]
            return [*copies, *flips]

        if protocols.is_measurement(op):
            # A measurement that is not a plain MeasurementGate: decompose it
            # one level and defer each resulting piece.
            return [
                defer(piece, None) for piece in protocols.decompose_once(op)
            ]

        if not op.classical_controls:
            return op

        controls = []
        for c in op.classical_controls:
            if not isinstance(c, value.KeyCondition):
                raise ValueError('Only KeyConditions are allowed.')
            if c.key not in measurement_qubits:
                raise ValueError(
                    f'Deferred measurement for key={c.key} not found.')
            qubits = measurement_qubits[c.key]
            if len(qubits) != 1:
                # TODO: Multi-qubit conditions require
                # https://github.com/quantumlib/Cirq/issues/4512
                # Remember to update docstring above once this works.
                raise ValueError('Only single qubit conditions are allowed.')
            controls.extend(qubits)

        control_values = [tuple(range(1, q.dimension)) for q in controls]
        return op.without_classical_controls().controlled_by(
            *controls, control_values=control_values)

    deferred = transformer_primitives.map_operations_and_unroll(
        circuit=circuit,
        map_func=defer,
        tags_to_ignore=context.tags_to_ignore if context else (),
        raise_if_add_qubits=False,
    ).unfreeze()

    # Finally measure every ancilla under the key it stands in for.
    for key, ancillas in measurement_qubits.items():
        deferred.append(ops.measure(*ancillas, key=key))
    return deferred