def test_multi_thread(self):
    """Test that multi-threaded queuing works correctly.

    Builds a QNodeCollection of ``n_batches`` identical QNodes that share one
    ``default.qubit`` device, then evaluates the collection repeatedly with
    ``parallel=True``. The test fails if the collection cannot be constructed
    or if any parallel evaluation raises.
    """
    n_qubits = 4
    n_batches = 5
    dev = qml.device("default.qubit", wires=n_qubits)

    def circuit(inputs, weights):
        # One RY rotation per input value, each on its own wire.
        for index, value in enumerate(inputs):
            qml.RY(value, wires=index)
        # Chain of CNOTs between neighbouring wires.
        for index in range(n_qubits - 1):
            qml.CNOT(wires=(index, index + 1))
        # One RX rotation per weight, each on its own wire.
        for index, weight in enumerate(weights):
            qml.RX(weight, wires=index)
        return [qml.expval(qml.PauliZ(wires=i)) for i in range(n_qubits)]

    # Shape tuple for the weights; the trailing comma matters, since
    # ``(n_qubits)`` on its own is just a parenthesized int.
    weight_shapes = {"weights": (n_qubits,)}

    try:
        qnode = QNodeCollection([QNode(circuit, dev) for _ in range(n_batches)])
    except Exception as e:
        # Surface the underlying error instead of discarding it.
        pytest.fail(f"QNodeCollection cannot be instantiated: {e!r}")

    x = np.random.rand(n_qubits).astype(np.float64)
    p = np.random.rand(*weight_shapes["weights"]).astype(np.float64)

    try:
        # NOTE(review): repeating the call presumably gives a threading race
        # more chances to show up.
        for _ in range(10):
            qnode(x, p, parallel=True)
    except Exception as e:
        # ``except Exception`` (not a bare ``except:``) so KeyboardInterrupt /
        # SystemExit propagate, and the original error is reported.
        pytest.fail(f"Multi-threading on QuantumTape failed: {e!r}")
def test_error_backprop_wrong_interface(self, interface, tol):
    """Requesting diff_method='backprop' on default.qubit.autograd with a
    non-Autograd interface must raise a QuantumFunctionError."""
    expected_msg = (
        "default.qubit.autograd only supports diff_method='backprop' "
        "when using the autograd interface"
    )

    dev = qml.device("default.qubit.autograd", wires=1)

    def circuit(theta, w=None):
        qml.RZ(theta, wires=w)
        return qml.expval(qml.PauliX(w))

    # Both the decorator call and its application stay inside the context,
    # so the error is caught wherever QNode construction raises it.
    with pytest.raises(qml.QuantumFunctionError, match=expected_msg):
        qml.qnode(dev, diff_method="backprop", interface=interface)(circuit)
def test_error_backprop_wrong_interface(self, interface, tol):
    """Requesting diff_method='backprop' on default.qubit.jax with a non-JAX
    interface must raise an error.

    The expected exception class depends on whether tape mode is active:
    QuantumFunctionError in tape mode, ValueError otherwise.
    """
    dev = qml.device("default.qubit.jax", wires=1)

    def circuit(theta, w=None):
        qml.RZ(theta, wires=w)
        return qml.expval(qml.PauliX(w))

    if qml.tape_mode_active():
        expected_error = qml.QuantumFunctionError
    else:
        expected_error = ValueError

    expected_msg = (
        "default.qubit.jax only supports diff_method='backprop' "
        "when using the jax interface"
    )

    with pytest.raises(expected_error, match=expected_msg):
        qml.qnode(dev, diff_method="backprop", interface=interface)(circuit)
def test_operator_all_wires(self, monkeypatch, tol):
    """An operator declared to act on all wires must do so.

    RX is patched to require ``AllWires``. On a 2-wire device, applying it to
    wire 0 alone raises a QuantumFunctionError when the QNode is called. On a
    1-wire device the same circuit evaluates to cos(0.5).
    """
    monkeypatch.setattr(qml.RX, "num_wires", qml.operation.AllWires)

    def circuit(angle):
        qml.RX(angle, wires=0)
        return qml.expval(qml.PauliZ(0))

    two_wire_qnode = QNode(circuit, qml.device("default.qubit", wires=2))
    with pytest.raises(qml.QuantumFunctionError, match="Operator RX must act on all wires"):
        two_wire_qnode(0.5)

    one_wire_qnode = QNode(circuit, qml.device("default.qubit", wires=1))
    result = one_wire_qnode(0.5)
    assert np.allclose(result, np.cos(0.5), atol=tol, rtol=0)
def test_no_jax_interface_applied(self):
    """qml.probs with interface="jax" and diff_method="backprop" must not
    raise, because the JAX interface is not applied in that case.

    When the JAX interface is applied, only the expectation value and the
    variance of a QNode can be obtained. This test checks that returning
    probabilities still works and yields [1, 0] for the untouched |0> state.
    """
    def circuit():
        return qml.probs(wires=0)

    dev = qml.device("default.qubit.jax", wires=1, shots=None)
    qnode = qml.qnode(dev, diff_method="backprop", interface="jax")(circuit)

    probs = qnode()
    expected = jnp.array([1, 0])
    assert jnp.allclose(probs, expected)