Example #1
0
 def set_state_vector(self, state: 'cirq.STATE_VECTOR_LIKE'):
     """Overwrites the stored state vector, in place, with ``state``.

     ``state`` is converted by ``wave_function.to_valid_state_vector`` using
     this object's qubit count (``len(self.qubit_map)``), its qid shape and
     ``self._dtype``. The converted values are then written into the existing
     ``self._state_vector`` array with ``np.copyto`` (no new array is bound).

     Args:
         state: The new state, in any form accepted by
             ``wave_function.to_valid_state_vector``.
     """
     num_qubits = len(self.qubit_map)
     # None is supplied as the fallback value for protocols.qid_shape.
     shape = protocols.qid_shape(self, None)
     new_state = wave_function.to_valid_state_vector(state,
                                                     num_qubits,
                                                     qid_shape=shape,
                                                     dtype=self._dtype)
     np.copyto(self._state_vector, new_state)
Example #2
0
 def set_state_vector(self, state: Union[int, np.ndarray]):
     """Replaces the contents of ``self._state_vector`` with ``state``.

     The value is validated/converted by
     ``wave_function.to_valid_state_vector`` (qubit count taken from
     ``self.qubit_map``, qid shape from ``protocols.qid_shape``, dtype from
     ``self._dtype``) and copied into the existing array in place.

     Args:
         state: A computational basis state index or a full state vector.
     """
     validated = wave_function.to_valid_state_vector(
         state,
         len(self.qubit_map),
         dtype=self._dtype,
         qid_shape=protocols.qid_shape(self, None))
     np.copyto(self._state_vector, validated)
Example #3
0
    def compute_displays_sweep(
        self,
        program: Union[circuits.Circuit, schedules.Schedule],
        params: Optional[study.Sweepable] = None,
        qubit_order: ops.QubitOrderOrList = ops.QubitOrder.DEFAULT,
        initial_state: Union[int, np.ndarray] = 0,
    ) -> List[study.ComputeDisplaysResult]:
        """Computes displays in the supplied Circuit.

        In contrast to `compute_displays`, this allows for sweeping
        over different parameter values.

        Args:
            program: The circuit or schedule to simulate. A schedule is
                converted with ``to_circuit()`` first.
            params: Parameters to run with the program. A falsy value (e.g.
                None) is replaced by a single empty ParamResolver.
            qubit_order: Determines the canonical ordering of the qubits used to
                define the order of amplitudes in the wave function.
            initial_state: If an int, the state is set to the computational
                basis state corresponding to this state.
                Otherwise if this is a np.ndarray it is the full initial state.
                In this case it must be the correct size, be normalized (an L2
                norm of 1), and be safely castable to an appropriate
                dtype for the simulator.

        Returns:
            List of ComputeDisplaysResults for this run, one for each
            possible parameter resolver. For an empty circuit every result
            has empty display values.
        """
        circuit = (program if isinstance(program, circuits.Circuit) else
                   program.to_circuit())
        param_resolvers = study.to_resolvers(params or study.ParamResolver({}))
        qubit_order = ops.QubitOrder.as_qubit_order(qubit_order)
        qubits = qubit_order.order_for(circuit.all_qubits())

        compute_displays_results = [
        ]  # type: List[study.ComputeDisplaysResult]
        for param_resolver in param_resolvers:
            display_values = {}  # type: ignore

            # An empty circuit has no moments and so no displays; without this
            # guard circuit[0] would raise IndexError.
            if len(circuit) > 0:
                # Displays in the first moment are evaluated on the initial
                # state itself.
                moment = circuit[0]
                state = wave_function.to_valid_state_vector(
                    initial_state, num_qubits=len(qubits))
                qubit_map = {q: i for i, q in enumerate(qubits)}
                _enter_moment_display_values_into_dictionary(
                    display_values, moment, state, qubit_order, qubit_map)

                # Every later moment is paired (via zip with circuit[1:]) with
                # the step result that precedes it.
                all_step_results = self.simulate_moment_steps(
                    circuit, param_resolver, qubit_order, initial_state)
                for step_result, moment in zip(all_step_results, circuit[1:]):
                    _enter_moment_display_values_into_dictionary(
                        display_values, moment, step_result.state(),
                        qubit_order, step_result.qubit_map)

            compute_displays_results.append(
                study.ComputeDisplaysResult(params=param_resolver,
                                            display_values=display_values))

        return compute_displays_results
Example #4
0
def to_valid_density_matrix(
        density_matrix_rep: Union[int, np.ndarray],
        num_qubits: Optional[int] = None,
        *,  # Force keyword arguments
        qid_shape: Optional[Tuple[int, ...]] = None,
        dtype: Type[np.number] = np.complex64) -> np.ndarray:
    """Verifies the density_matrix_rep is valid and converts it to ndarray form.

    Three representations are accepted: a rank-2 numpy array (returned as-is
    after validation), a rank-1 numpy array (a wave function), or an int (a
    computational basis state). The latter two are turned into the pure-state
    density matrix |psi><psi|.

    Args:
        density_matrix_rep: A density matrix (rank-2 array), a wave function
            (rank-1 array) or a computational basis state (int).
        num_qubits: The number of qubits for the density matrix. The
            density_matrix_rep must be valid for this number of qubits.
        qid_shape: The qid shape of the state vector.  Specify this argument
            when using qudits.
        dtype: The numpy dtype of the density matrix. It is used when building
            the state from a wave function or basis state, and checked for
            equality when density_matrix_rep is a rank-2 array.

    Returns:
        A numpy matrix corresponding to the density matrix on the given number
        of qubits.

    Raises:
        ValueError: A rank-2 input has the wrong shape, is not hermitian, does
            not have trace 1, has the wrong dtype, or has an eigenvalue at or
            below -1e-8; or the vector/int form is rejected by
            ``wave_function.to_valid_state_vector``.
    """
    qid_shape = _qid_shape_from_args(num_qubits, qid_shape)

    is_matrix = (isinstance(density_matrix_rep, np.ndarray)
                 and density_matrix_rep.ndim == 2)
    if not is_matrix:
        # Wave function or basis state: build the pure-state projector.
        psi = wave_function.to_valid_state_vector(density_matrix_rep,
                                                  len(qid_shape),
                                                  qid_shape=qid_shape,
                                                  dtype=dtype)
        return np.outer(psi, np.conj(psi))

    matrix = density_matrix_rep
    dim = np.prod(qid_shape, dtype=int)
    if matrix.shape != (dim, dim):
        raise ValueError(
            'Density matrix was not square and of size 2 ** num_qubit, '
            'instead was {}'.format(matrix.shape))
    if not np.allclose(matrix, np.conj(matrix).T):
        raise ValueError('The density matrix is not hermitian.')
    trace = np.trace(matrix)
    if not np.isclose(trace, 1.0):
        raise ValueError(
            'Density matrix did not have trace 1 but instead {}'.format(trace))
    if matrix.dtype != dtype:
        raise ValueError(
            'Density matrix had dtype {} but expected {}'.format(
                matrix.dtype, dtype))
    # Small negative eigenvalues (> -1e-8) are tolerated as numerical noise.
    if not np.all(np.linalg.eigvalsh(matrix) > -1e-8):
        raise ValueError('The density matrix is not positive semidefinite.')
    return matrix
Example #5
0
    def _base_iterator(
        self,
        circuit: circuits.Circuit,
        qubit_order: ops.QubitOrderOrList,
        initial_state: Union[int, np.ndarray],
        perform_measurements: bool = True,
    ) -> Iterator[simulator.StepResult]:
        """Simulates ``circuit`` moment by moment, yielding a step per moment.

        Args:
            circuit: The circuit to simulate.
            qubit_order: Determines the canonical ordering of the qubits.
            initial_state: A computational basis state index or a full state
                vector, as accepted by ``wave_function.to_valid_state_vector``.
            perform_measurements: If False, measurement operations are skipped.

        Yields:
            A SimulatorStep after each moment holding the current state and
            that moment's measurement results. The yielded array may later be
            reused by this loop as scratch space for ``apply_unitary``.

        Raises:
            TypeError: If an operation is neither unitary nor a measurement
                and cannot be decomposed into such operations.
        """
        qubits = ops.QubitOrder.as_qubit_order(qubit_order).order_for(
            circuit.all_qubits())
        num_qubits = len(qubits)
        qubit_map = {q: i for i, q in enumerate(qubits)}
        # dtype is passed by keyword: newer to_valid_state_vector signatures
        # make it keyword-only.
        state = wave_function.to_valid_state_vector(initial_state, num_qubits,
                                                    dtype=self._dtype)

        def on_stuck(bad_op: ops.Operation):
            return TypeError(
                "Can't simulate unknown operations that don't specify a "
                "_unitary_ method, a _decompose_ method, or "
                "(_has_unitary_ + _apply_unitary_) methods"
                ": {!r}".format(bad_op))

        def keep(potential_op: ops.Operation) -> bool:
            return (protocols.has_unitary(potential_op)
                    or ops.MeasurementGate.is_measurement(potential_op))

        state = np.reshape(state, (2, ) * num_qubits)
        buffer = np.empty((2, ) * num_qubits, dtype=self._dtype)
        for moment in circuit:
            measurements = collections.defaultdict(
                list)  # type: Dict[str, List[bool]]

            unitary_ops_and_measurements = protocols.decompose(
                moment.operations, keep=keep, on_stuck_raise=on_stuck)

            for op in unitary_ops_and_measurements:
                indices = [qubit_map[qubit] for qubit in op.qubits]
                if ops.MeasurementGate.is_measurement(op):
                    gate = cast(ops.MeasurementGate,
                                cast(ops.GateOperation, op).gate)
                    if perform_measurements:
                        # Measure updates inline.
                        bits, _ = wave_function.measure_state_vector(
                            state, indices, state)
                        # invert_mask may be empty or shorter than the number
                        # of measured qubits (trailing qubits not inverted).
                        # Pad it so zip() cannot silently drop measured bits.
                        invert_mask = tuple(gate.invert_mask or ())
                        invert_mask += (False, ) * (len(bits) -
                                                    len(invert_mask))
                        corrected = [
                            bit ^ mask for bit, mask in zip(bits, invert_mask)
                        ]
                        measurements[cast(str, gate.key)].extend(corrected)
                else:
                    result = protocols.apply_unitary(
                        op,
                        args=protocols.ApplyUnitaryArgs(
                            state, buffer, indices))
                    # apply_unitary may write into either array; swap roles
                    # so the other one becomes the scratch buffer.
                    if result is buffer:
                        buffer = state
                    state = result
            yield SimulatorStep(state, measurements, qubit_map, self._dtype)
Example #6
0
    def simulate_sweep(
        self,
        program: Union[circuits.Circuit, schedules.Schedule],
        params: study.Sweepable = study.ParamResolver({}),
        qubit_order: ops.QubitOrderOrList = ops.QubitOrder.DEFAULT,
        initial_state: Union[int, np.ndarray] = 0,
    ) -> List['SimulationTrialResult']:
        """Simulates the entire supplied Circuit.

        This method returns a result which allows access to the entire
        wave function. In contrast to simulate, this allows for sweeping
        over different parameter values.

        Args:
            program: The circuit or schedule to simulate. A schedule is
                converted with ``to_circuit()`` first.
            params: Parameters to run with the program.
            qubit_order: Determines the canonical ordering of the qubits used to
                define the order of amplitudes in the wave function.
            initial_state: If an int, the state is set to the computational
                basis state corresponding to this state.
                Otherwise if this is a np.ndarray it is the full initial state.
                In this case it must be the correct size, be normalized (an L2
                norm of 1), and be safely castable to an appropriate
                dtype for the simulator.

        Returns:
            List of SimulatorTrialResults for this run, one for each
            possible parameter resolver.
        """
        circuit = (program if isinstance(program, circuits.Circuit)
                   else program.to_circuit())
        param_resolvers = study.to_resolvers(params or study.ParamResolver({}))

        trial_results = []  # type: List[SimulationTrialResult]
        qubit_order = ops.QubitOrder.as_qubit_order(qubit_order)
        for param_resolver in param_resolvers:
            step_result = None
            all_step_results = self.simulate_moment_steps(circuit,
                                                          param_resolver,
                                                          qubit_order,
                                                          initial_state)
            measurements = {}  # type: Dict[str, np.ndarray]
            for step_result in all_step_results:
                # A key reported by several steps keeps the latest value.
                for k, v in step_result.measurements.items():
                    measurements[k] = np.array(v, dtype=bool)
            # Test for "no step at all" explicitly instead of relying on the
            # truthiness of a step result object.
            if step_result is not None:
                final_state = step_result.state()
            else:
                # Empty circuit, so final state should be initial state.
                num_qubits = len(qubit_order.order_for(circuit.all_qubits()))
                final_state = wave_function.to_valid_state_vector(initial_state,
                                                                  num_qubits)
            trial_results.append(SimulationTrialResult(
                params=param_resolver,
                measurements=measurements,
                final_state=final_state))

        return trial_results
Example #7
0
 def set_state_vector(self, state: Union[int, np.ndarray]):
     """Overwrites ``self._state_vector`` in place with a validated ``state``.

     Args:
         state: A computational basis state index or a full state vector, as
             accepted by ``wave_function.to_valid_state_vector``.
     """
     # dtype is passed by keyword: it works with the older positional-or-
     # keyword signature and with newer ones where dtype is keyword-only
     # (a positional third argument there raises TypeError).
     update_state = wave_function.to_valid_state_vector(
         state, len(self.qubit_map), dtype=self._dtype)
     np.copyto(self._state_vector, update_state)
Example #8
0
    def _base_iterator(
        self,
        circuit: circuits.Circuit,
        qubit_order: ops.QubitOrderOrList,
        initial_state: Union[int, np.ndarray],
        perform_measurements: bool = True,
    ) -> Iterator:
        """Simulates ``circuit`` moment by moment, yielding a step per moment.

        Display operations are skipped. Every other operation is decomposed
        until it is unitary, a mixture or a measurement, and handed to
        ``self._simulate_unitary``, ``self._simulate_measurement`` or
        ``self._simulate_mixture`` respectively.

        Args:
            circuit: The circuit to simulate.
            qubit_order: Determines the canonical ordering of the qubits.
            initial_state: A computational basis state index or a full state
                vector, as accepted by ``wave_function.to_valid_state_vector``.
            perform_measurements: If False, measurement operations are skipped.

        Yields:
            For an empty circuit, one SparseSimulatorStep with the initial
            state; otherwise one SparseSimulatorStep per moment with the
            current state and that moment's measurements.

        Raises:
            TypeError: If an operation cannot be decomposed into unitaries,
                mixtures or measurements.
        """
        qubits = ops.QubitOrder.as_qubit_order(qubit_order).order_for(
            circuit.all_qubits())
        num_qubits = len(qubits)
        qubit_map = {q: i for i, q in enumerate(qubits)}
        # dtype by keyword: newer to_valid_state_vector signatures make it
        # keyword-only.
        state = wave_function.to_valid_state_vector(initial_state, num_qubits,
                                                    dtype=self._dtype)
        if len(circuit) == 0:
            yield SparseSimulatorStep(state, {}, qubit_map, self._dtype)

        def on_stuck(bad_op: ops.Operation):
            return TypeError(
                "Can't simulate unknown operations that don't specify a "
                "_unitary_ method, a _decompose_ method, "
                "(_has_unitary_ + _apply_unitary_) methods, "
                "(_has_mixture_ + _mixture_) methods, or are measurements"
                ": {!r}".format(bad_op))

        def keep(potential_op: ops.Operation) -> bool:
            # The order of this is optimized to call has_xxx methods first.
            return (protocols.has_unitary(potential_op)
                    or protocols.has_mixture(potential_op)
                    or protocols.is_measurement(potential_op))

        data = _StateAndBuffer(state=np.reshape(state, (2, ) * num_qubits),
                               buffer=np.empty((2, ) * num_qubits,
                                               dtype=self._dtype))
        for moment in circuit:
            measurements = collections.defaultdict(
                list)  # type: Dict[str, List[bool]]

            non_display_ops = (op for op in moment
                               if not isinstance(op, (
                                   ops.SamplesDisplay, ops.WaveFunctionDisplay,
                                   ops.DensityMatrixDisplay)))
            unitary_ops_and_measurements = protocols.decompose(
                non_display_ops, keep=keep, on_stuck_raise=on_stuck)

            for op in unitary_ops_and_measurements:
                indices = [qubit_map[qubit] for qubit in op.qubits]
                if protocols.has_unitary(op):
                    self._simulate_unitary(op, data, indices)
                elif protocols.is_measurement(op):
                    # Do measurements second, since there may be mixtures that
                    # operate as measurements.
                    # TODO: support measurement outside the computational basis.
                    if perform_measurements:
                        self._simulate_measurement(op, data, indices,
                                                   measurements, num_qubits)
                elif protocols.has_mixture(op):
                    self._simulate_mixture(op, data, indices)

            yield SparseSimulatorStep(state_vector=data.state,
                                      measurements=measurements,
                                      qubit_map=qubit_map,
                                      dtype=self._dtype)
Example #9
0
    def simulate_sweep(
        self,
        program: Union[circuits.Circuit, schedules.Schedule],
        params: study.Sweepable,
        qubit_order: ops.QubitOrderOrList = ops.QubitOrder.DEFAULT,
        initial_state: Any = None,
    ) -> List['SimulationTrialResult']:
        """Simulates the supplied Circuit or Schedule with Qulacs.

        For every parameter resolver the program is resolved, translated
        operation by operation into a ``qulacs.QuantumCircuit`` and run on a
        state obtained from ``self._get_qulacs_state``.

        Args:
            program: The circuit or schedule to simulate. A schedule is
                converted with ``to_circuit()`` first.
            params: Parameters to run with the program.
            qubit_order: Determines the canonical ordering of the qubits. This
                is often used in specifying the initial state, i.e. the
                ordering of the computational basis states.
            initial_state: The initial state for the simulation, in any form
                accepted by ``wave_function.to_valid_state_vector``. If None,
                the state returned by ``self._get_qulacs_state`` is used as is.

        Returns:
            List of SimulationTrialResults for this run, one for each
            possible parameter resolver.
        """

        def to_cptp(channel_source, indices):
            """Builds a Qulacs CPTP gate from the Kraus operators of a value."""
            # DenseMatrix gets the targets in reversed order, matching how the
            # mixture path always passed them.
            targets = list(reversed(indices))
            return qulacs.gate.CPTP([
                qulacs.gate.DenseMatrix(targets, kraus.astype(np.complex128))
                for kraus in protocols.channel(channel_source)
            ])

        # Schedules have no moments/all_qubits; convert like the other
        # simulate_* entry points do.
        circuit = (program if isinstance(program, circuits.Circuit) else
                   program.to_circuit())
        trial_results = []
        # sweep for each parameters
        for resolver in study.to_resolvers(params):

            # result circuit
            cirq_circuit = protocols.resolve_parameters(circuit, resolver)
            qubits = ops.QubitOrder.as_qubit_order(qubit_order).order_for(
                cirq_circuit.all_qubits())
            qubit_map = {q: i for i, q in enumerate(qubits)}
            num_qubits = len(qubits)

            # create state
            qulacs_state = self._get_qulacs_state(num_qubits)
            if initial_state is not None:
                cirq_state = wave_function.to_valid_state_vector(
                    initial_state, num_qubits)
                qulacs_state.load(cirq_state)
                del cirq_state

            # create circuit; classical register address ->
            # (measurement key, whether that bit is inverted)
            qulacs_circuit = qulacs.QuantumCircuit(num_qubits)
            address_to_key = {}
            register_address = 0
            for moment in cirq_circuit:
                for op in moment.operations:
                    # Qulacs qubit index for each of op's qubits: cirq
                    # position i becomes num_qubits - 1 - i.
                    indices = [
                        num_qubits - 1 - qubit_map[qubit]
                        for qubit in op.qubits
                    ]
                    if self._try_append_gate(op, qulacs_circuit, indices):
                        continue

                    # Not every operation is a GateOperation.
                    gate = getattr(op, 'gate', None)
                    channel_source = op if gate is None else gate

                    if isinstance(gate, ops.ResetChannel):
                        # Reset only the target qubit(s) through the reset
                        # channel's Kraus operators; set_zero_state() would
                        # reset every qubit in the register.
                        qulacs_circuit.add_gate(
                            to_cptp(channel_source, indices))

                    elif protocols.is_measurement(op):
                        key = protocols.measurement_key(gate)
                        # invert_mask may be missing, empty or shorter than
                        # the number of measured qubits.
                        mask = tuple(getattr(gate, 'invert_mask', None) or ())
                        for position, index in enumerate(indices):
                            qulacs_circuit.add_gate(
                                qulacs.gate.Measurement(
                                    index, register_address))
                            address_to_key[register_address] = (
                                key,
                                position < len(mask) and bool(mask[position]))
                            register_address += 1

                    elif protocols.has_mixture(op):
                        qulacs_circuit.add_gate(
                            to_cptp(channel_source, indices))

            # perform simulation
            qulacs_circuit.update_quantum_state(qulacs_state)

            # fetch final state and measurement results
            final_state = qulacs_state.get_vector()
            measurements = collections.defaultdict(list)
            for register_index in range(register_address):
                key, inverted = address_to_key[register_index]
                value = qulacs_state.get_classical_value(register_index)
                measurements[key].append(value ^ 1 if inverted else value)

            # create result for this parameter
            result = SimulationTrialResult(params=resolver,
                                           measurements=measurements,
                                           final_simulator_state=final_state)
            trial_results.append(result)

            # release memory
            del qulacs_state
            del qulacs_circuit

        return trial_results

    def _base_iterator(
        self,
        circuit: circuits.Circuit,
        qubit_order: ops.QubitOrderOrList,
        initial_state: Union[int, np.ndarray],
        perform_measurements: bool = True,
    ) -> Iterator:
        """Simulates ``circuit`` with Qulacs, yielding a single final step.

        Runs of unitary operations are queued in a ``qulacs.QuantumCircuit``
        and applied to ``data.state`` (through ``self._simulate_on_qulacs``)
        only when a measurement or mixture needs the current state, or after
        the last moment. Because of this batching, exactly one
        ``SparseSimulatorStep`` is yielded (after the whole circuit, or with
        the initial state for an empty circuit). It carries the measurement
        results of every moment.

        Args:
            circuit: The circuit to simulate.
            qubit_order: Determines the canonical ordering of the qubits.
            initial_state: A computational basis state index or a full state
                vector, as accepted by ``wave_function.to_valid_state_vector``.
            perform_measurements: If False, measurement operations are skipped.

        Yields:
            One SparseSimulatorStep holding the final state and measurements.

        Raises:
            TypeError: If an operation cannot be decomposed into unitaries,
                mixtures or measurements.
        """
        qubits = ops.QubitOrder.as_qubit_order(qubit_order).order_for(
            circuit.all_qubits())
        num_qubits = len(qubits)
        qubit_map = {q: i for i, q in enumerate(qubits)}
        # dtype by keyword: newer to_valid_state_vector signatures make it
        # keyword-only.
        state = wave_function.to_valid_state_vector(initial_state, num_qubits,
                                                    dtype=self._dtype)

        if len(circuit) == 0:
            yield SparseSimulatorStep(state, {}, qubit_map, self._dtype)
            # Stop here; otherwise the final yield below would run with a
            # never-assigned measurements dict.
            return

        def on_stuck(bad_op: ops.Operation):
            return TypeError(
                "Can't simulate unknown operations that don't specify a "
                "_unitary_ method, a _decompose_ method, "
                "(_has_unitary_ + _apply_unitary_) methods, "
                "(_has_mixture_ + _mixture_) methods, or are measurements"
                ": {!r}".format(bad_op))

        def keep(potential_op: ops.Operation) -> bool:
            # The order of this is optimized to call has_xxx methods first.
            return (protocols.has_unitary(potential_op)
                    or protocols.has_mixture(potential_op)
                    or protocols.is_measurement(potential_op))

        def append_unitary(qulacs_circ, op, targets):
            """Queues unitary ``op`` on ``qulacs_circ`` at Qulacs ``targets``."""
            gate = getattr(op, 'gate', None)
            # Native Qulacs gates are used only when the cirq gate is exactly
            # that gate. Powers or phase-shifted variants (e.g. CNOT**0.5)
            # fall through to the dense-matrix path instead of being applied
            # as the full gate.
            if isinstance(gate, ops.pauli_gates._PauliX):
                qulacs_circ.add_X_gate(targets[0])
            elif isinstance(gate, ops.pauli_gates._PauliY):
                qulacs_circ.add_Y_gate(targets[0])
            elif isinstance(gate, ops.pauli_gates._PauliZ):
                qulacs_circ.add_Z_gate(targets[0])
            elif gate == ops.H:
                qulacs_circ.add_H_gate(targets[0])
            elif gate == ops.CNOT:
                qulacs_circ.add_CNOT_gate(targets[0], targets[1])
            elif gate == ops.CZ:
                qulacs_circ.add_CZ_gate(targets[0], targets[1])
            elif gate == ops.SWAP:
                qulacs_circ.add_SWAP_gate(targets[0], targets[1])
            else:
                # Any other unitary (previously unlisted ones were silently
                # dropped) is applied as an explicit matrix. Targets are
                # reversed as the per-gate dense-matrix branches did.
                qulacs_circ.add_dense_matrix_gate(list(reversed(targets)),
                                                  protocols.unitary(op))

        data = _StateAndBuffer(state=np.reshape(state, (2, ) * num_qubits),
                               buffer=np.empty((2, ) * num_qubits,
                                               dtype=self._dtype))
        shape = data.state.shape

        # Qulacs
        qulacs_state = qulacs.QuantumStateGpu(int(num_qubits))
        qulacs_circuit = qulacs.QuantumCircuit(int(num_qubits))
        # True while qulacs_circuit holds gates not yet applied to data.state.
        pending = False
        # Shared by all moments: only one step is yielded, so results from
        # every moment must be kept, not just the last one.
        measurements = collections.defaultdict(
            list)  # type: Dict[str, List[bool]]

        for moment in circuit:
            non_display_ops = (op for op in moment
                               if not isinstance(op, (
                                   ops.SamplesDisplay, ops.WaveFunctionDisplay,
                                   ops.DensityMatrixDisplay)))

            unitary_ops_and_measurements = protocols.decompose(
                non_display_ops, keep=keep, on_stuck_raise=on_stuck)

            for op in unitary_ops_and_measurements:
                if protocols.has_unitary(op):
                    # Qulacs qubit index: cirq position i -> num_qubits - 1 - i.
                    append_unitary(qulacs_circuit, op, [
                        num_qubits - 1 - qubit_map[qubit]
                        for qubit in op.qubits
                    ])
                    pending = True
                    continue

                # Measurements and mixtures act on data.state directly, so use
                # its numpy axes (qubit_map positions), not the Qulacs indices.
                # NOTE(review): assumes _simulate_measurement/_simulate_mixture
                # are the sparse-simulator helpers taking state axes; confirm.
                axes = [qubit_map[qubit] for qubit in op.qubits]
                if protocols.is_measurement(op):
                    # Do measurements second, since there may be mixtures that
                    # operate as measurements.
                    # TODO: support measurement outside the computational basis.
                    if perform_measurements:
                        if pending:
                            self._simulate_on_qulacs(data, shape, qulacs_state,
                                                     qulacs_circuit)
                            # Fresh circuit so applied gates are not re-run.
                            qulacs_circuit = qulacs.QuantumCircuit(
                                int(num_qubits))
                            pending = False
                        self._simulate_measurement(op, data, axes,
                                                   measurements, num_qubits)

                elif protocols.has_mixture(op):
                    if pending:
                        self._simulate_on_qulacs(data, shape, qulacs_state,
                                                 qulacs_circuit)
                        qulacs_circuit = qulacs.QuantumCircuit(int(num_qubits))
                        pending = False
                    self._simulate_mixture(op, data, axes)

        if pending:
            self._simulate_on_qulacs(data, shape, qulacs_state, qulacs_circuit)

        del qulacs_state
        del qulacs_circuit

        yield SparseSimulatorStep(state_vector=data.state,
                                  measurements=measurements,
                                  qubit_map=qubit_map,
                                  dtype=self._dtype)
Example #11
0
    def _base_iterator(
        self,
        circuit: circuits.Circuit,
        qubit_order: ops.QubitOrderOrList,
        initial_state: Union[int, np.ndarray],
        perform_measurements: bool = True,
    ) -> Iterator:
        """Simulates ``circuit`` moment by moment, yielding a step per moment.

        Display operations are skipped. Other operations are decomposed until
        they are unitary or measurements; unitaries go through
        ``protocols.apply_unitary`` and measurements through
        ``wave_function.measure_state_vector``.

        Args:
            circuit: The circuit to simulate.
            qubit_order: Determines the canonical ordering of the qubits.
            initial_state: A computational basis state index or a full state
                vector, as accepted by ``wave_function.to_valid_state_vector``.
            perform_measurements: If False, measurement operations are skipped.

        Yields:
            For an empty circuit, one SparseSimulatorStep with the initial
            state; otherwise one SparseSimulatorStep per moment with the
            current state and that moment's measurements.

        Raises:
            TypeError: If an operation is neither unitary nor a measurement
                and cannot be decomposed into such operations.
        """
        qubits = ops.QubitOrder.as_qubit_order(qubit_order).order_for(
            circuit.all_qubits())
        num_qubits = len(qubits)
        qubit_map = {q: i for i, q in enumerate(qubits)}
        # dtype by keyword: newer to_valid_state_vector signatures make it
        # keyword-only.
        state = wave_function.to_valid_state_vector(initial_state, num_qubits,
                                                    dtype=self._dtype)
        if len(circuit) == 0:
            yield SparseSimulatorStep(state, {}, qubit_map, self._dtype)

        def on_stuck(bad_op: ops.Operation):
            return TypeError(
                "Can't simulate unknown operations that don't specify a "
                "_unitary_ method, a _decompose_ method, or "
                "(_has_unitary_ + _apply_unitary_) methods"
                ": {!r}".format(bad_op))

        def keep(potential_op: ops.Operation) -> bool:
            return (protocols.has_unitary(potential_op)
                    or protocols.is_measurement(potential_op))

        state = np.reshape(state, (2, ) * num_qubits)
        buffer = np.empty((2, ) * num_qubits, dtype=self._dtype)
        for moment in circuit:
            measurements = collections.defaultdict(
                list)  # type: Dict[str, List[bool]]

            non_display_ops = (op for op in moment
                               if not isinstance(op, (
                                   ops.SamplesDisplay, ops.WaveFunctionDisplay,
                                   ops.DensityMatrixDisplay)))
            unitary_ops_and_measurements = protocols.decompose(
                non_display_ops, keep=keep, on_stuck_raise=on_stuck)

            for op in unitary_ops_and_measurements:
                indices = [qubit_map[qubit] for qubit in op.qubits]
                # TODO: Support measurements outside of computational basis.
                # TODO: Support mixtures.
                meas = ops.op_gate_of_type(op, ops.MeasurementGate)
                if meas:
                    if perform_measurements:
                        # Measure updates inline.
                        bits, _ = wave_function.measure_state_vector(
                            state, indices, state)
                        # invert_mask may be empty or shorter than the number
                        # of measured qubits (trailing qubits not inverted).
                        # Pad it so zip() cannot silently drop measured bits.
                        invert_mask = tuple(meas.invert_mask or ())
                        invert_mask += (False, ) * (len(bits) -
                                                    len(invert_mask))
                        corrected = [
                            bit ^ mask for bit, mask in zip(bits, invert_mask)
                        ]
                        key = protocols.measurement_key(meas)
                        measurements[key].extend(corrected)
                else:
                    result = protocols.apply_unitary(
                        op,
                        args=protocols.ApplyUnitaryArgs(
                            state, buffer, indices))
                    # apply_unitary may write into either array; swap roles
                    # so the other one becomes the scratch buffer.
                    if result is buffer:
                        buffer = state
                    state = result
            yield SparseSimulatorStep(state_vector=state,
                                      measurements=measurements,
                                      qubit_map=qubit_map,
                                      dtype=self._dtype)