def test_construct_subcircuit(self):
        """Check the metric tensor subcircuit recorded for each parameter.

        A two-wire circuit with three parameters is constructed (not
        evaluated), and the ``queue``, ``scale`` and ``observable`` entries
        stored under each parameter key are inspected.
        """
        device = qml.device("default.qubit", wires=2)

        def circuit(a, b, c):
            qml.RX(a, wires=0)
            qml.RY(b, wires=0)
            qml.CNOT(wires=[0, 1])
            qml.PhaseShift(c, wires=1)
            return qml.expval(qml.PauliX(0)), qml.expval(qml.PauliX(1))

        qnode = QubitQNode(circuit, device)
        qnode.metric_tensor([1, 1, 1], only_construct=True)
        subcircuits = qnode._metric_tensor_subcircuits

        # parameter 0: the RX gate opens the circuit, so its queue is empty
        first = subcircuits[(0,)]
        assert len(first["queue"]) == 0
        assert first["scale"] == [-0.5]
        assert isinstance(first["observable"][0], qml.PauliX)

        # parameter 1: only the RX gate comes before the RY gate
        second = subcircuits[(1,)]
        assert len(second["queue"]) == 1
        assert second["scale"] == [-0.5]
        assert isinstance(second["queue"][0], qml.RX)
        assert isinstance(second["observable"][0], qml.PauliY)

        # parameter 2: RX, RY and CNOT come before the PhaseShift gate
        third = subcircuits[(2,)]
        assert len(third["queue"]) == 3
        assert third["scale"] == [1]
        for position, gate_type in enumerate((qml.RX, qml.RY, qml.CNOT)):
            assert isinstance(third["queue"][position], gate_type)

        # the observable is a Hermitian wrapping the PhaseShift generator
        hermitian = third["observable"][0]
        assert isinstance(hermitian, qml.Hermitian)
        assert np.all(hermitian.params[0] == qml.PhaseShift.generator[0])
    def test_evaluate_subcircuits(self, tol):
        """Check the value stored for each subcircuit after evaluation.

        The metric tensor is evaluated at fixed parameter values, and the
        ``result`` entry stored under each parameter key is compared with
        its closed-form expression to within ``tol``.
        """
        device = qml.device("default.qubit", wires=2)

        def circuit(a, b, c):
            qml.RX(a, wires=0)
            qml.RY(b, wires=0)
            qml.CNOT(wires=[0, 1])
            qml.PhaseShift(c, wires=1)
            return qml.expval(qml.PauliX(0)), qml.expval(qml.PauliX(1))

        qnode = QubitQNode(circuit, device)

        a, b, c = 0.432, 0.12, -0.432

        # evaluate the subcircuits; their values are read back below
        qnode.metric_tensor((a, b, c))

        # (parameter key, expected subcircuit result), in parameter order
        expectations = [
            ((0,), 0.25),
            ((1,), np.cos(a)**2 / 4),
            ((2,), (3 - 2 * np.cos(a)**2 * np.cos(2 * b) - np.cos(2 * a)) / 16),
        ]

        for key, expected in expectations:
            res = qnode._metric_tensor_subcircuits[key]["result"]
            assert np.allclose(res, expected, atol=tol, rtol=0)
    def test_no_generator(self):
        """Check that constructing the metric tensor subcircuits raises a
        ``QuantumFunctionError`` when the circuit uses an operation
        (here ``Rot``) that has no defined generator."""
        device = qml.device("default.qubit", wires=1)

        def circuit(a):
            qml.Rot(a, 0, 0, wires=0)
            return qml.expval(qml.PauliX(0))

        qnode = QubitQNode(circuit, device)

        expected_message = "has no defined generator"
        with pytest.raises(QuantumFunctionError, match=expected_message):
            qnode.metric_tensor([1], only_construct=True)
    def test_evaluate_diag_metric_tensor(self, tol):
        """Check the evaluated metric tensor against its expected diagonal.

        The tensor returned for fixed parameter values is compared, to
        within ``tol``, with a diagonal matrix built from closed-form
        entries.
        """
        device = qml.device("default.qubit", wires=2)

        def circuit(a, b, c):
            qml.RX(a, wires=0)
            qml.RY(b, wires=0)
            qml.CNOT(wires=[0, 1])
            qml.PhaseShift(c, wires=1)
            return qml.expval(qml.PauliX(0)), qml.expval(qml.PauliX(1))

        qnode = QubitQNode(circuit, device)

        a, b, c = 0.432, 0.12, -0.432

        # evaluate the metric tensor
        tensor = qnode.metric_tensor((a, b, c))

        # unscaled diagonal entries; the shared factor 1/4 is applied below
        entry_rx = 1
        entry_ry = np.cos(a)**2
        entry_phase = (3 - 2 * np.cos(a)**2 * np.cos(2 * b) - np.cos(2 * a)) / 4
        diagonal = np.array([entry_rx, entry_ry, entry_phase]) / 4

        assert np.allclose(tensor, np.diag(diagonal), atol=tol, rtol=0)
    def test_generator_no_expval(self, monkeypatch):
        """Check that constructing the metric tensor subcircuits raises a
        ``QuantumFunctionError`` when an operation's generator object is
        not an observable."""
        device = qml.device("default.qubit", wires=1)

        def circuit(a):
            qml.RX(a, wires=0)
            return qml.expval(qml.PauliX(0))

        qnode = QubitQNode(circuit, device)

        # temporarily replace the RX generator with one whose first entry
        # is the RX operation itself rather than an observable
        bad_generator = [qml.RX, 1]
        expected_message = "no corresponding observable"

        with monkeypatch.context() as patcher:
            patcher.setattr("pennylane.RX.generator", bad_generator)

            with pytest.raises(QuantumFunctionError, match=expected_message):
                qnode.metric_tensor([1], only_construct=True)
    def test_construct_subcircuit_layers(self):
        """Check the subcircuits constructed for a circuit with layers.

        An eight-parameter, three-wire circuit is constructed (not
        evaluated). The test expects four subcircuits, keyed by the
        parameter indices of each layer, and checks the gate types in each
        ``queue`` and the observable types in each ``observable`` list.
        """
        device = qml.device("default.qubit", wires=3)

        def circuit(params):
            # section 1
            qml.RX(params[0], wires=0)
            # section 2
            qml.RY(params[1], wires=0)
            qml.CNOT(wires=[0, 1])
            qml.CNOT(wires=[1, 2])
            # section 3
            qml.RX(params[2], wires=0)
            qml.RY(params[3], wires=1)
            qml.RZ(params[4], wires=2)
            qml.CNOT(wires=[0, 1])
            qml.CNOT(wires=[1, 2])
            # section 4
            qml.RX(params[5], wires=0)
            qml.RY(params[6], wires=1)
            qml.RZ(params[7], wires=2)
            qml.CNOT(wires=[0, 1])
            qml.CNOT(wires=[1, 2])
            return (
                qml.expval(qml.PauliX(0)),
                qml.expval(qml.PauliX(1)),
                qml.expval(qml.PauliX(2)),
            )

        qnode = QubitQNode(circuit, device)

        params = np.ones(8)
        qnode.metric_tensor([params], only_construct=True)
        subcircuits = qnode._metric_tensor_subcircuits

        # the circuit is expected to split into 4 independent
        # sections or layers when the subcircuits are constructed
        assert len(subcircuits) == 4

        def check_layer(key, queue_types, observable_types):
            # lengths are checked first, then each entry's type by position
            layer = subcircuits[key]
            assert len(layer["queue"]) == len(queue_types)
            assert len(layer["observable"]) == len(observable_types)
            for position, gate_type in enumerate(queue_types):
                assert isinstance(layer["queue"][position], gate_type)
            for position, obs_type in enumerate(observable_types):
                assert isinstance(layer["observable"][position], obs_type)

        pauli_xyz = (qml.PauliX, qml.PauliY, qml.PauliZ)

        # first layer subcircuit
        check_layer((0,), (), (qml.PauliX,))

        # second layer subcircuit
        check_layer((1,), (qml.RX,), (qml.PauliY,))

        # third layer subcircuit
        check_layer(
            (2, 3, 4),
            (qml.RX, qml.RY, qml.CNOT, qml.CNOT),
            pauli_xyz,
        )

        # fourth layer subcircuit
        check_layer(
            (5, 6, 7),
            (
                qml.RX, qml.RY, qml.CNOT, qml.CNOT,
                qml.RX, qml.RY, qml.RZ, qml.CNOT, qml.CNOT,
            ),
            pauli_xyz,
        )