def test_correct_output_tape_tf(self, ansatz, params):
        """Test that the output is correct when using TensorFlow and
        calling the adjoint metric tensor directly on a tape.

        The metric tensor is computed twice: from the executed QNode's tape
        together with an explicit device, and by wrapping the QNode with
        ``hybrid=False``. Both results are compared against the reference
        produced by ``autodiff_metric_tensor``.
        """

        tf = pytest.importorskip("tensorflow")

        expected = autodiff_metric_tensor(ansatz, self.num_wires)(*params)
        # Force double precision: ``tf.Variable`` infers float32 from Python
        # floats, which would make the comparison with the float64 reference
        # depend on single-precision rounding.
        t_params = tuple(tf.Variable(p, dtype=tf.float64) for p in params)
        dev = qml.device("default.qubit.tf", wires=self.num_wires)

        @qml.qnode(dev, interface="tf")
        def circuit(*params):
            """Circuit with dummy output to create a QNode."""
            ansatz(*params, dev.wires)
            return qml.expval(qml.PauliZ(0))

        # Execute the QNode so its ``qtape`` reflects these parameters before
        # it is handed to the metric-tensor function below. No gradient tape is
        # needed: only forward values are checked in this test.
        circuit(*t_params)
        mt = qml.adjoint_metric_tensor(circuit.qtape, dev)

        # Bring the reference into the same shape as the computed tensor.
        expected = qml.math.reshape(expected, qml.math.shape(mt))
        assert qml.math.allclose(mt, expected)

        mt = qml.adjoint_metric_tensor(circuit, hybrid=False)(*t_params)
        assert qml.math.allclose(mt, expected)
    def test_correct_output_tape_jax(self, ansatz, params):
        """Test that the output is correct when using JAX and
        calling the adjoint metric tensor directly on a tape.

        The metric tensor is computed from the executed QNode's tape with an
        explicit device and from the QNode wrapped with ``hybrid=False``. Both
        are compared against the ``autodiff_metric_tensor`` reference.
        """

        jax = pytest.importorskip("jax")

        # ``jax.config.update`` works across JAX versions. Importing ``config``
        # from the ``jax.config`` submodule no longer works in newer JAX.
        # NOTE(review): this sets a process-wide flag that outlives the test.
        jax.config.update("jax_enable_x64", True)

        expected = autodiff_metric_tensor(ansatz, self.num_wires)(*params)
        j_params = tuple(jax.numpy.array(p) for p in params)
        dev = qml.device("default.qubit.jax", wires=self.num_wires)

        @qml.qnode(dev, interface="jax")
        def circuit(*params):
            """Circuit with dummy output to create a QNode."""
            ansatz(*params, dev.wires)
            return qml.expval(qml.PauliZ(0))

        # Execute the QNode so its ``qtape`` reflects these parameters.
        circuit(*j_params)
        mt = qml.adjoint_metric_tensor(circuit.qtape, dev)
        # Bring the reference into the same shape as the computed tensor.
        expected = qml.math.reshape(expected, qml.math.shape(mt))
        assert qml.math.allclose(mt, expected)

        mt = qml.adjoint_metric_tensor(circuit, hybrid=False)(*j_params)
        assert qml.math.allclose(mt, expected)
    def test_error_finite_shots(self):
        """Calling ``qml.adjoint_metric_tensor`` with a device configured for a
        finite number of shots (here ``shots=1``) must raise ``ValueError``."""
        shot_device = qml.device("default.qubit", wires=2, shots=1)

        with qml.tape.JacobianTape() as tape:
            qml.RX(0.2, wires=0)
            qml.RY(1.9, wires=1)

        expected_message = "The adjoint method for the metric tensor"
        with pytest.raises(ValueError, match=expected_message):
            qml.adjoint_metric_tensor(tape, device=shot_device)
    def test_error_wrong_object_passed(self):
        """Passing a bare quantum function, which is neither a tape nor a
        QNode, must raise ``qml.QuantumFunctionError``."""

        def plain_qfunc(x, y):
            qml.RX(x, wires=0)
            qml.RY(y, wires=1)

        device = qml.device("default.qubit", wires=2)
        error_pattern = "The passed object is not a "

        with pytest.raises(qml.QuantumFunctionError, match=error_pattern):
            qml.adjoint_metric_tensor(plain_qfunc, device=device)
    def test_correct_output_qnode_tf(self, ansatz, params):
        """Differentiate the adjoint metric tensor of a TensorFlow QNode with
        ``tf.GradientTape`` and compare the resulting Jacobian with the
        autograd Jacobian of the ``autodiff_metric_tensor`` reference."""

        tf = pytest.importorskip("tensorflow")

        reference_fn = autodiff_metric_tensor(ansatz, self.num_wires)
        expected = qml.jacobian(reference_fn)(*params)
        tf_args = tuple(tf.Variable(arg, dtype=tf.float64) for arg in params)
        dev = qml.device("default.qubit", wires=self.num_wires)

        @qml.qnode(dev, interface="tf")
        def circuit(*args):
            """Circuit with dummy output to create a QNode."""
            ansatz(*args, dev.wires)
            return qml.expval(qml.PauliZ(0))

        with tf.GradientTape() as tape:
            metric = qml.adjoint_metric_tensor(circuit)(*tf_args)

        metric_jac = tape.jacobian(metric, tf_args)

        if not isinstance(metric_jac, tuple):
            assert qml.math.allclose(metric_jac, expected)
            return

        # A single trainable argument gives a bare array from ``qml.jacobian``
        # but a one-element tuple from TF, so the two are aligned first.
        if len(metric_jac) == 1 and not isinstance(expected, tuple):
            expected = (expected, )
        assert all(
            qml.math.allclose(jac_part, ref_part)
            for jac_part, ref_part in zip(metric_jac, expected))
    def test_correct_output_qnode_torch(self, ansatz, params):
        """Test that the derivative is correct when using Torch and
        calling the adjoint metric tensor on a QNode.

        The Jacobian of the adjoint metric tensor, taken with
        ``torch.autograd.functional.jacobian`` with respect to every argument,
        is compared with the autograd Jacobian of the
        ``autodiff_metric_tensor`` reference.
        """

        torch = pytest.importorskip("torch")

        expected = qml.jacobian(autodiff_metric_tensor(
            ansatz, self.num_wires))(*params)
        t_params = tuple(
            torch.tensor(p, requires_grad=True, dtype=torch.float64)
            for p in params)
        dev = qml.device("default.qubit", wires=self.num_wires)

        @qml.qnode(dev, interface="torch")
        def circuit(*params):
            """Circuit with dummy output to create a QNode."""
            ansatz(*params, dev.wires)
            return qml.expval(qml.PauliZ(0))

        mt_fn = qml.adjoint_metric_tensor(circuit)
        # ``inputs`` must be passed as ONE tuple. Unpacking it with ``*``
        # would bind the 2nd, 3rd, ... tensors to ``create_graph``,
        # ``strict``, ... and only differentiate w.r.t. the first argument.
        mt_jac = torch.autograd.functional.jacobian(mt_fn, t_params)

        # With tuple inputs torch returns one Jacobian per input. Normalise both
        # sides to tuples so a single-argument reference (a bare array from
        # ``qml.jacobian``) lines up with the one-element torch result.
        if not isinstance(mt_jac, tuple):
            mt_jac = (mt_jac, )
        if not isinstance(expected, tuple):
            expected = (expected, )

        assert all(
            qml.math.allclose(_mt, _exp)
            for _mt, _exp in zip(mt_jac, expected))
    def test_correct_derivative_qnode_jax(self, ansatz, params):
        """Test that the derivative is correct when using JAX and
        calling the adjoint metric tensor on a QNode.

        The Jacobian of the adjoint metric tensor with respect to every
        argument is compared with the autograd Jacobian of the
        ``autodiff_metric_tensor`` reference.
        """
        # Renamed from ``test_correct_output_qnode_jax``. A second method with
        # that name (the non-derivative output check) is defined later in the
        # class and silently replaced this one, so it was never collected.

        jax = pytest.importorskip("jax")

        # ``jax.config.update`` works across JAX versions. Importing ``config``
        # from the ``jax.config`` submodule no longer works in newer JAX.
        # NOTE(review): this sets a process-wide flag that outlives the test.
        jax.config.update("jax_enable_x64", True)

        expected = qml.jacobian(autodiff_metric_tensor(
            ansatz, self.num_wires))(*params)
        j_params = tuple(jax.numpy.array(p) for p in params)
        dev = qml.device("default.qubit", wires=self.num_wires)

        @qml.qnode(dev, interface="jax")
        def circuit(*params):
            """Circuit with dummy output to create a QNode."""
            ansatz(*params, dev.wires)
            return qml.expval(qml.PauliZ(0))

        mt_fn = qml.adjoint_metric_tensor(circuit, hybrid=True)
        # Differentiate with respect to every positional argument.
        argnums = list(range(len(params)))
        mt_jac = jax.jacobian(mt_fn, argnums=argnums)(*j_params)

        if isinstance(mt_jac, tuple):
            # A single argument gives a bare array from ``qml.jacobian`` but a
            # one-element tuple from JAX, so the two are aligned first.
            if not isinstance(expected, tuple) and len(mt_jac) == 1:
                expected = (expected, )
            assert all(
                qml.math.allclose(_mt, _exp)
                for _mt, _exp in zip(mt_jac, expected))
        else:
            assert qml.math.allclose(mt_jac, expected)
    def test_autograd_with_other_device(self):
        """Wrap an autograd QNode with ``qml.adjoint_metric_tensor`` while
        supplying a separate ``device``, and check that the Jacobian of the
        result matches the autograd Jacobian of the reference metric tensor."""
        ansatz, params = fubini_ansatz2, fubini_params[2]

        reference_fn = autodiff_metric_tensor(ansatz, self.num_wires)
        expected = qml.jacobian(reference_fn)(*params)
        dev = qml.device("default.qubit", wires=self.num_wires)
        extra_dev = qml.device("default.qubit.autograd", wires=self.num_wires)

        @qml.qnode(dev, interface="autograd")
        def circuit(*args):
            """Circuit with dummy output to create a QNode."""
            ansatz(*args, dev.wires)
            return qml.expval(qml.PauliZ(0))

        metric_fn = qml.adjoint_metric_tensor(circuit, device=extra_dev)
        metric_jac = qml.jacobian(metric_fn)(*params)

        if not isinstance(metric_jac, tuple):
            assert qml.math.allclose(metric_jac, expected)
            return
        assert all(
            qml.math.allclose(got, want)
            for got, want in zip(metric_jac, expected))
    def test_correct_output_qnode_jax(self, ansatz, params):
        """Test that the output is correct when using JAX and
        calling the adjoint metric tensor on a QNode.

        The metric tensor returned for JAX arrays is compared against the
        ``autodiff_metric_tensor`` reference evaluated at the same parameters.
        """

        jax = pytest.importorskip("jax")

        # ``jax.config.update`` works across JAX versions. Importing ``config``
        # from the ``jax.config`` submodule no longer works in newer JAX.
        # NOTE(review): this sets a process-wide flag that outlives the test.
        jax.config.update("jax_enable_x64", True)

        expected = autodiff_metric_tensor(ansatz, self.num_wires)(*params)
        j_params = tuple(jax.numpy.array(p) for p in params)
        dev = qml.device("default.qubit", wires=self.num_wires)

        @qml.qnode(dev, interface="jax")
        def circuit(*params):
            """Circuit with dummy output to create a QNode."""
            ansatz(*params, dev.wires)
            return qml.expval(qml.PauliZ(0))

        mt = qml.adjoint_metric_tensor(circuit)(*j_params)

        if isinstance(mt, tuple):
            assert all(
                qml.math.allclose(_mt, _exp)
                for _mt, _exp in zip(mt, expected))
        else:
            assert qml.math.allclose(mt, expected)
    def test_correct_output_tape_autograd(self, ansatz, params):
        """Compare the adjoint metric tensor of an autograd QNode against the
        ``autodiff_metric_tensor`` reference. It is computed both from the
        executed tape plus an explicit device and from the QNode wrapped with
        ``hybrid=False``."""
        reference = autodiff_metric_tensor(ansatz, self.num_wires)(*params)
        dev = qml.device("default.qubit.autograd", wires=self.num_wires)

        @qml.qnode(dev, interface="autograd")
        def circuit(*args):
            """Circuit with dummy output to create a QNode."""
            ansatz(*args, dev.wires)
            return qml.expval(qml.PauliZ(0))

        circuit(*params)  # execute so ``circuit.qtape`` matches ``params``
        from_tape = qml.adjoint_metric_tensor(circuit.qtape, dev)
        reference = qml.math.reshape(reference, qml.math.shape(from_tape))
        assert qml.math.allclose(from_tape, reference)

        from_qnode = qml.adjoint_metric_tensor(circuit, hybrid=False)(*params)
        assert qml.math.allclose(from_qnode, reference)
    def test_warning_multiple_devices(self):
        """Test that a warning is issued if an ExpvalCost with multiple
        devices is passed."""
        dev1 = qml.device("default.qubit", wires=2)
        dev2 = qml.device("default.qubit", wires=1)
        H = qml.Hamiltonian([0.2, 0.9], [qml.PauliZ(0), qml.PauliY(0)])

        # ``wires`` is unused here; presumably kept to match the ansatz
        # signature that ExpvalCost expects.
        def ansatz(x, wires):
            qml.RX(x, wires=0)

        cost = qml.ExpvalCost(ansatz, H, [dev1, dev2])
        # Only the emitted warning is under test; the returned function is
        # deliberately discarded.
        with pytest.warns(UserWarning, match="ExpvalCost was instantiated"):
            qml.adjoint_metric_tensor(cost)
    def test_correct_output_tape_torch(self, ansatz, params):
        """Test that the output is correct when using Torch and
        calling the adjoint metric tensor directly on a tape.

        The metric tensor is computed from the executed QNode's tape with an
        explicit device and from the QNode wrapped with ``hybrid=False``. Both
        are compared against the ``autodiff_metric_tensor`` reference.
        """

        torch = pytest.importorskip("torch")

        expected = autodiff_metric_tensor(ansatz, self.num_wires)(*params)
        # Force double precision: ``torch.tensor`` infers float32 from Python
        # floats, which would make the comparison with the float64 reference
        # depend on single-precision rounding.
        t_params = tuple(
            torch.tensor(p, requires_grad=True, dtype=torch.float64)
            for p in params)
        dev = qml.device("default.qubit.torch", wires=self.num_wires)

        @qml.qnode(dev, interface="torch")
        def circuit(*params):
            """Circuit with dummy output to create a QNode."""
            ansatz(*params, dev.wires)
            return qml.expval(qml.PauliZ(0))

        # Execute the QNode so its ``qtape`` reflects these parameters.
        circuit(*t_params)
        mt = qml.adjoint_metric_tensor(circuit.qtape, dev)
        # Bring the reference into the same shape as the computed tensor.
        expected = qml.math.reshape(expected, qml.math.shape(mt))
        assert qml.math.allclose(mt.detach().numpy(), expected)

        mt = qml.adjoint_metric_tensor(circuit, hybrid=False)(*t_params)
        assert qml.math.allclose(mt, expected)
    def test_correct_output_qnode_autograd(self, ansatz, params):
        """Compare the adjoint metric tensor of an autograd QNode, evaluated
        at ``params``, against the ``autodiff_metric_tensor`` reference."""
        reference = autodiff_metric_tensor(ansatz, self.num_wires)(*params)
        dev = qml.device("default.qubit", wires=self.num_wires)

        @qml.qnode(dev, interface="autograd")
        def circuit(*args):
            """Circuit with dummy output to create a QNode."""
            ansatz(*args, dev.wires)
            return qml.expval(qml.PauliZ(0))

        metric = qml.adjoint_metric_tensor(circuit)(*params)

        if not isinstance(metric, tuple):
            assert qml.math.allclose(metric, reference)
            return
        assert all(
            qml.math.allclose(got, want)
            for got, want in zip(metric, reference))