Beispiel #1
0
    def test_defines_correct_capabilities_directly_from_class(self):
        """Test that the device defines the right capabilities.

        Checks that ``DefaultQubitJax`` reports no support for reversible
        differentiation and advertises ``"jax"`` as its passthru interface.
        """
        # NOTE(review): despite the test name, the capabilities are queried on
        # an instance rather than on the class itself — confirm whether
        # calling ``DefaultQubitJax.capabilities()`` directly was intended.
        dev = DefaultQubitJax(wires=1)
        cap = dev.capabilities()
        # Identity check instead of ``== False`` (PEP 8 / E712): this also
        # fails if the flag is some other falsy value such as ``0`` or ``None``.
        assert cap["supports_reversible_diff"] is False
        assert cap["passthru_interface"] == "jax"
Beispiel #2
0
    def test_full_subsystem(self, mocker):
        """Check that a state spanning every device wire is stored verbatim.

        The state vector targets all three wires ``"a"``, ``"b"`` and ``"c"``;
        afterwards the flattened device state must equal the input, and the
        ``_scatter`` helper must not have been called on this path.
        """
        dev = DefaultQubitJax(wires=["a", "b", "c"])
        target_wires = qml.wires.Wires(["a", "b", "c"])
        amplitudes = jnp.array([1, 0, 0, 0, 1, 0, 1, 1]) / 2.0

        scatter_spy = mocker.spy(dev, "_scatter")
        dev._apply_state_vector(state=amplitudes, device_wires=target_wires)

        stored = dev._state.flatten()
        assert jnp.all(stored == amplitudes)
        scatter_spy.assert_not_called()
Beispiel #3
0
    def test_partial_subsystem(self, mocker):
        """Test applying a state vector to a subset of wires of the full subsystem.

        A two-wire state is applied to wires ``"a"`` and ``"c"`` of a three-wire
        device; ``_scatter`` is expected to be called on this path.
        """

        dev = DefaultQubitJax(wires=["a", "b", "c"])
        state = jnp.array([1, 0, 1, 0]) / jnp.sqrt(2.0)
        state_wires = qml.wires.Wires(["a", "c"])

        spy = mocker.spy(dev, "_scatter")
        dev._apply_state_vector(state=state, device_wires=state_wires)

        # Assumes dev._state has one axis per wire in device order, so axis 1 is
        # wire "b"; summing over it should give back the (a, c) amplitudes if the
        # untouched wire is left in a single basis state — TODO confirm against
        # _apply_state_vector.
        res = jnp.sum(dev._state, axis=1).flatten()

        # 1/sqrt(2) is not exactly representable and the device state may use a
        # different (e.g. complex) dtype than ``state``, so compare within a
        # tolerance rather than with exact floating-point equality.
        assert jnp.allclose(res, state)
        spy.assert_called()