def test_defines_correct_capabilities_directly_from_class(self):
    """Test that the device defines the right capabilities.

    Builds a one-wire ``DefaultQubitJax`` and checks two entries of the
    dict returned by ``capabilities()``:

    * ``supports_reversible_diff`` must be exactly ``False``;
    * ``passthru_interface`` must be ``"jax"``.
    """
    dev = DefaultQubitJax(wires=1)
    cap = dev.capabilities()
    # Identity check instead of ``== False`` (PEP 8 / E712). It is also
    # stricter: a falsy non-bool such as ``0`` no longer passes silently.
    assert cap["supports_reversible_diff"] is False
    assert cap["passthru_interface"] == "jax"
def test_full_subsystem(self, mocker):
    """Check that a state vector spanning every device wire is applied without ``_scatter``.

    The device has wires ``a``, ``b`` and ``c``. It receives a normalised
    8-amplitude vector (squared magnitudes sum to 4 / 4 = 1) addressed to
    those same three wires. The test expects two things: the flattened
    device state equals the input, and ``_scatter`` is never called.
    """
    dev = DefaultQubitJax(wires=["a", "b", "c"])
    scatter_spy = mocker.spy(dev, "_scatter")

    amplitudes = jnp.array([1, 0, 0, 0, 1, 0, 1, 1]) / 2.0
    all_wires = qml.wires.Wires(["a", "b", "c"])
    dev._apply_state_vector(state=amplitudes, device_wires=all_wires)

    stored = dev._state.flatten()
    assert jnp.all(stored == amplitudes)
    scatter_spy.assert_not_called()
def test_partial_subsystem(self, mocker):
    """Check that a state vector on a subset of device wires is applied through ``_scatter``.

    A normalised two-wire vector is written onto wires ``a`` and ``c`` of a
    three-wire device (``a``, ``b``, ``c``). The test sums the device state
    over axis 1, which it treats as wire ``b``, then flattens the result.
    It expects that result to reproduce the input amplitudes, and expects
    ``_scatter`` to have been called.
    """
    dev = DefaultQubitJax(wires=["a", "b", "c"])
    scatter_spy = mocker.spy(dev, "_scatter")

    amplitudes = jnp.array([1, 0, 1, 0]) / jnp.sqrt(2.0)
    subset = qml.wires.Wires(["a", "c"])
    dev._apply_state_vector(state=amplitudes, device_wires=subset)

    reduced = jnp.sum(dev._state, axis=(1,)).flatten()
    assert jnp.all(reduced == amplitudes)
    scatter_spy.assert_called()