def test_error_invalid_callable(self):
    """Passing a non-callable object to batch_transform must raise a ValueError."""
    not_a_function = 5
    expected_message = "does not appear to be a valid Python function"

    with pytest.raises(ValueError, match=expected_message):
        qml.batch_transform(not_a_function)
def test_expand_qnode_with_kwarg(self, mocker, perform_expansion):
    """Keyword arguments supplied when wrapping a QNode with the transform are
    respected by the expansion step."""

    class MyTransform:
        """Helper class so that the transform method can be spied on."""

        def my_transform(self, tape, **kwargs):
            # Emit two copies of the incoming tape and no post-processing function.
            return [tape.copy() for _ in range(2)], None

    # Spy on the class attribute before the bound method is taken below.
    spy_transform = mocker.spy(MyTransform, "my_transform")
    transform_fn = qml.batch_transform(
        MyTransform().my_transform,
        expand_fn=self.expand_logic_with_kwarg,
    )
    spy_expand = mocker.spy(transform_fn, "expand_fn")

    dev = qml.device("default.qubit", wires=2)

    @functools.partial(transform_fn, perform_expansion=perform_expansion)
    @qml.qnode(dev)
    def qnode(x):
        qml.PhaseShift(0.5, wires=0)
        return qml.expval(qml.PauliX(0))

    qnode(0.2)

    spy_transform.assert_called()
    # expand_fn is invoked whether or not the expansion is actually performed.
    spy_expand.assert_called()

    positional_args, _ = spy_transform.call_args
    # Position 0 holds the MyTransform instance; the tape follows it.
    input_tape = positional_args[1]
    expected_name = "RZ" if perform_expansion else "PhaseShift"

    assert len(input_tape.operations) == 1
    only_op = input_tape.operations[0]
    assert only_op.name == expected_name
    assert only_op.parameters == [0.5]
def test_expand_fn(self, mocker):
    """When an expansion function is given, the tape is expanded before the
    transform receives it."""

    class MyTransform:
        """Helper class so that the transform method can be spied on."""

        def my_transform(self, tape):
            # Emit two copies of the incoming tape and no post-processing function.
            return [tape.copy() for _ in range(2)], None

    # Spy on the class attribute before the bound method is taken below.
    spy_transform = mocker.spy(MyTransform, "my_transform")
    transform_fn = qml.batch_transform(
        MyTransform().my_transform,
        expand_fn=self.phaseshift_expand,
    )

    with qml.tape.QuantumTape() as tape:
        qml.PhaseShift(0.5, wires=0)
        qml.expval(qml.PauliX(0))

    # Installed before the transform runs so the expansion call is recorded.
    spy_expand = mocker.spy(transform_fn, "expand_fn")
    transform_fn(tape)

    spy_transform.assert_called()
    spy_expand.assert_called()

    positional_args, _ = spy_transform.call_args
    # Position 0 holds the MyTransform instance; the tape follows it.
    input_tape = positional_args[1]

    assert len(input_tape.operations) == 1
    only_op = input_tape.operations[0]
    assert only_op.name == "RZ"
    assert only_op.parameters == [0.5]
def test_expand_fn_with_kwarg(self, mocker, perform_expansion):
    """Keyword arguments passed when transforming a tape directly are
    respected by the expansion step."""

    class MyTransform:
        """Helper class so that the transform method can be spied on."""

        def my_transform(self, tape, **kwargs):
            # Emit two copies of the incoming tape and no post-processing function.
            return [tape.copy() for _ in range(2)], None

    # Spy on the class attribute before the bound method is taken below.
    spy_transform = mocker.spy(MyTransform, "my_transform")
    transform_fn = qml.batch_transform(
        MyTransform().my_transform,
        expand_fn=self.expand_logic_with_kwarg,
    )

    with qml.tape.QuantumTape() as tape:
        qml.PhaseShift(0.5, wires=0)
        qml.expval(qml.PauliX(0))

    # Installed before the transform runs so the expansion call is recorded.
    spy_expand = mocker.spy(transform_fn, "expand_fn")
    transform_fn(tape, perform_expansion=perform_expansion)

    spy_transform.assert_called()
    # expand_fn is invoked whether or not the expansion is actually performed.
    spy_expand.assert_called()

    positional_args, _ = spy_transform.call_args
    # Position 0 holds the MyTransform instance; the tape follows it.
    input_tape = positional_args[1]
    expected_name = "RZ" if perform_expansion else "PhaseShift"

    assert len(input_tape.operations) == 1
    only_op = input_tape.operations[0]
    assert only_op.name == expected_name
    assert only_op.parameters == [0.5]
def test_not_differentiable(self):
    """A batch transform created with differentiable=False still evaluates,
    but cannot be differentiated."""

    def duplicate_and_sum(tape):
        # Two identical tapes; their results are combined with qml.math.sum.
        return [tape.copy(), tape.copy()], qml.math.sum

    my_transform = qml.batch_transform(duplicate_and_sum, differentiable=False)
    dev = qml.device("default.qubit", wires=2)

    @my_transform
    @qml.qnode(dev)
    def circuit(x):
        qml.Hadamard(wires=0)
        qml.RY(x, wires=0)
        return qml.expval(qml.PauliX(0))

    # The forward pass works and produces a non-trivial value.
    res = circuit(0.5)
    assert isinstance(res, float)
    assert not np.allclose(res, 0)

    # Differentiating through the non-differentiable transform is expected
    # to trigger this warning.
    with pytest.warns(UserWarning, match="Output seems independent of input"):
        qml.grad(circuit)(0.5)