def test_prune_tensors_no_pruning_took_place(self, mock_device):
        """Test that the _prune_tensors auxiliary method returns its
        argument unchanged when there is nothing to prune.

        The input here is a single observable (not a Tensor built with
        ``@``), so there are no Identity factors to remove and the very
        same object is expected back.
        """
        obs = qml.PauliX(1)

        def circuit(x):
            return qml.expval(obs)

        qnode = BaseQNode(circuit, mock_device)

        # Identity (``is``) rather than equality: the contract under test is
        # "return the original object", not "return something equal to it".
        assert qnode._prune_tensors(obs) is obs

    def test_prune_tensors(self, mock_device):
        """Test that the _prune_tensors auxiliary method correctly prunes
        a single Identity from a Tensor, leaving only the remaining
        non-Identity observable."""
        px = qml.PauliX(1)
        obs = qml.Identity(0) @ px

        def circuit(x):
            return qml.expval(obs)

        qnode = BaseQNode(circuit, mock_device)

        # Removing Identity(0) from Identity(0) @ PauliX(1) should leave PauliX(1).
        assert qnode._prune_tensors(obs) == px