def test_refresh_access_token_raises_for_expired(self, monkeypatch, code):
    """Check that ``_refresh_access_token`` raises ``ExpiredRefreshTokenError``
    when the mocked ``requests.post`` response carries the status ``code``."""
    device = HQSDevice(3, machine=DUMMY_MACHINE, user_email=DUMMY_EMAIL)

    class ExpiredTokenResponse:
        """Stand-in for the object returned by ``requests.post``."""

        def __init__(self):
            self.status_code = code
            self.mock_post_response = {
                "status_code": str(code),
                "code": "Not 200",
                "detail": "Mock error for refresh.",
                "meta": "Something went wrong.",
            }

        def json(self):
            return self.mock_post_response

    fake_response = ExpiredTokenResponse()
    # Every POST issued while refreshing gets the canned error response.
    monkeypatch.setattr(requests, "post", lambda *args, **kwargs: fake_response)

    with pytest.raises(ExpiredRefreshTokenError, match="Invalid refresh token was used."):
        device._refresh_access_token()
def test_query_results(self, monkeypatch):
    """Check that ``_query_results`` issues exactly one GET request to the
    job endpoint, carrying the authorization header."""
    SOME_ACCESS_TOKEN = "XYZ789"
    SOME_JOB_ID = "JOB123"

    dev = HQSDevice(3, machine=DUMMY_MACHINE, user_email=DUMMY_EMAIL, retry_delay=0.1)
    monkeypatch.setattr(dev, "get_valid_access_token", lambda: SOME_ACCESS_TOKEN)

    # set num_calls=1 as the job was already submitted in cases when we get
    # the result
    canned_response = MockResponse(num_calls=1)
    recorded_calls = []

    def wrapper(job_endpoint, headers):
        # Record how ``requests.get`` was invoked, then hand back the mock.
        recorded_calls.append((job_endpoint, headers))
        return canned_response

    monkeypatch.setattr(requests, "get", wrapper)

    dev._query_results({"job": SOME_JOB_ID, "status": "not completed!"})

    assert len(recorded_calls) == 1

    job_endpoint, headers = recorded_calls[0]
    assert job_endpoint == dev.hostname + "/" + SOME_JOB_ID
    assert headers == {"Authorization": SOME_ACCESS_TOKEN}
def test_get_valid_access_token_new_tokens(self, access_token_expiry, refresh_token_expiry, monkeypatch):
    """Check that ``get_valid_access_token`` logs in and adopts the access
    and refresh tokens returned by ``_login``."""
    dev = HQSDevice(3, machine=DUMMY_MACHINE, user_email=DUMMY_EMAIL)

    # Where an expiry is given, set the token to an outdated token.
    stale_tokens = (
        ("_access_token", access_token_expiry),
        ("_refresh_token", refresh_token_expiry),
    )
    for attribute, expiry in stale_tokens:
        if expiry:
            setattr(dev, attribute, jwt.encode({"exp": expiry}, "secret"))

    new_access, new_refresh = 1234567, 111111
    monkeypatch.setattr(dev, "_login", lambda *args, **kwargs: (new_access, new_refresh))
    # Stubbed out so the test never writes tokens anywhere.
    monkeypatch.setattr(dev, "save_tokens", lambda *args, **kwargs: None)

    assert dev.get_valid_access_token() == new_access
    assert dev._refresh_token == new_refresh
def test_get_valid_access_token_new_refresh_token(self, monkeypatch):
    """Check that ``get_valid_access_token`` falls back to ``_login`` for new
    tokens when refreshing raises ``ExpiredRefreshTokenError``."""
    device = HQSDevice(3, machine=DUMMY_MACHINE, user_email=DUMMY_EMAIL)
    device._access_token = None
    device._refresh_token = "not None"

    canned_response = MockResponse()
    monkeypatch.setattr(requests, "post", lambda *args, **kwargs: canned_response)
    monkeypatch.setattr(device, "save_tokens", lambda *args, **kwargs: None)

    def raise_expired(*args, **kwargs):
        raise ExpiredRefreshTokenError

    monkeypatch.setattr(device, "_refresh_access_token", raise_expired)

    new_access, new_refresh = 1234567, 111111
    monkeypatch.setattr(device, "_login", lambda *args, **kwargs: (new_access, new_refresh))

    # The access and refresh token are set according to the output of _login
    assert device.get_valid_access_token() == new_access
    assert device._refresh_token == new_refresh
def test_get_valid_access_token_use_stored(self):
    """Test that ``get_valid_access_token`` returns the stored access token
    unchanged when that token exists and has not expired."""
    from datetime import timedelta

    dev = HQSDevice(3, machine=DUMMY_MACHINE, user_email=DUMMY_EMAIL)

    # Expiry roughly one year ahead. ``now.replace(now.year + 1)`` raises
    # ValueError when ``now`` falls on 29 February, so add a timedelta
    # instead (assumes the module-level ``now`` is a ``datetime``).
    valid_time = now + timedelta(days=365)
    token = jwt.encode({"exp": valid_time}, "secret")
    dev._access_token = token

    assert dev.get_valid_access_token() == token
def test_refresh_access_token(self, monkeypatch):
    """Check that ``_refresh_access_token`` returns the access token carried
    by the mocked token response."""
    device = HQSDevice(3, machine=DUMMY_MACHINE, user_email=DUMMY_EMAIL)

    token_response = MockResponseWithTokens()
    monkeypatch.setattr(requests, "post", lambda *args, **kwargs: token_response)

    refreshed = device._refresh_access_token()
    assert refreshed == MOCK_ACCESS_TOKEN
def test_generate_samples(self, results, indices):
    """Check that ``HQSDevice.generate_samples`` returns a
    ``(shots, num_wires)`` array matching the expected ``indices``."""
    device = HQSDevice(3, machine=DUMMY_MACHINE, shots=10, user_email=DUMMY_EMAIL)
    device._results = results

    samples = device.generate_samples()

    # One flattened row of indices, repeated once per shot.
    flat_indices = np.ravel(indices)
    expected_samples = np.stack([flat_indices] * 10)

    assert samples.shape == (device.shots, device.num_wires)
    assert np.all(samples == expected_samples)
def test_refresh_access_token_raises(self, monkeypatch):
    """Check that ``_refresh_access_token`` raises ``RequestFailedError`` for
    an unsuccessful request."""
    device = HQSDevice(3, machine=DUMMY_MACHINE, user_email=DUMMY_EMAIL)

    failed_response = MockResponseUnsuccessfulRequest()
    monkeypatch.setattr(requests, "post", lambda *args, **kwargs: failed_response)

    with pytest.raises(RequestFailedError, match="Failed to get access token"):
        device._refresh_access_token()
def test_get_job_retrieval_header(self, monkeypatch):
    """Check that ``get_job_retrieval_header`` returns a header whose only
    entry is the valid access token under ``Authorization``."""
    SOME_ACCESS_TOKEN = "XYZ789"

    device = HQSDevice(3, machine=DUMMY_MACHINE, user_email=DUMMY_EMAIL)
    monkeypatch.setattr(device, "get_valid_access_token", lambda: SOME_ACCESS_TOKEN)

    header = device.get_job_retrieval_header()
    assert header == {"Authorization": SOME_ACCESS_TOKEN}
def test_login(self, monkeypatch):
    """Check that a successful ``_login`` returns both the access token and
    the refresh token."""
    device = HQSDevice(3, machine=DUMMY_MACHINE, user_email=DUMMY_EMAIL)

    token_response = MockResponseWithTokens()
    monkeypatch.setattr(requests, "post", lambda *args, **kwargs: token_response)
    # Stub the password prompt so the test never waits for user input.
    monkeypatch.setattr(getpass, "getpass", lambda *args, **kwargs: None)

    tokens = device._login()
    access_token, refresh_token = tokens

    assert access_token == MOCK_ACCESS_TOKEN
    assert refresh_token == MOCK_REFRESH_TOKEN
def test_login_raises(self, monkeypatch):
    """Check that ``_login`` raises ``RequestFailedError`` when the login
    request is unsuccessful."""
    device = HQSDevice(3, machine=DUMMY_MACHINE, user_email=DUMMY_EMAIL)

    failed_response = MockResponseUnsuccessfulRequest()
    monkeypatch.setattr(requests, "post", lambda *args, **kwargs: failed_response)
    # Stub the password prompt so the test never waits for user input.
    monkeypatch.setattr(getpass, "getpass", lambda *args, **kwargs: None)

    with pytest.raises(RequestFailedError, match="Failed to get access token"):
        device._login()
def test_retry_delay(self):
    """Check that ``retry_delay`` can be set at construction and afterwards,
    and that a negative value is rejected."""
    device = HQSDevice(3, machine=DUMMY_MACHINE, user_email=DUMMY_EMAIL, retry_delay=2.5)

    # Value passed to the constructor is exposed unchanged.
    assert device.retry_delay == 2.5

    # Assigning a new positive value takes effect.
    device.retry_delay = 1.0
    assert device.retry_delay == 1.0

    # A negative delay is refused by the setter.
    with pytest.raises(qml.DeviceError, match="needs to be positive"):
        device.retry_delay = -5
def test_save_tokens_no_config_found(self, monkeypatch, tmpdir, tokens, new_dir):
    """Check that ``save_tokens`` writes to the user config directory when no
    configuration file exists."""
    config_file_name = "config.toml"
    mock_config = qml.Configuration(config_file_name)

    # With ``new_dir`` set, the target directory doesn't exist yet.
    directory = tmpdir.join("new_dir") if new_dir else tmpdir
    filepath = directory.join(config_file_name)

    mock_config._user_config_dir = directory
    # Only the filename is in the filepath: just like when no config file
    # was found
    mock_config._filepath = config_file_name

    monkeypatch.setattr(qml, "default_config", mock_config)
    HQSDevice(2, machine=DUMMY_MACHINE).save_tokens(*tokens)

    with open(filepath) as config_handle:
        configuration_file = toml.load(config_handle)

    saved = configuration_file["honeywell"]["global"]
    assert saved["access_token"] == tokens[0]

    # The refresh token is only checked when one was passed in.
    if len(tokens) > 1:
        assert saved["refresh_token"] == tokens[1]
def test_token_is_expired(self, token, expired):
    """Check that ``token_is_expired`` returns ``expired`` for a JWT whose
    ``exp`` claim is ``token``."""
    encoded = jwt.encode({"exp": token}, "secret")
    device = HQSDevice(2, machine=DUMMY_MACHINE)

    assert device.token_is_expired(encoded) is expired
def test_query_results_expected_response(self, monkeypatch):
    """Check that ``_query_results`` returns the payload of the mocked GET
    response."""
    SOME_ACCESS_TOKEN = "XYZ789"
    SOME_JOB_ID = "JOB123"

    device = HQSDevice(3, machine=DUMMY_MACHINE, user_email=DUMMY_EMAIL, retry_delay=0.01)
    monkeypatch.setattr(device, "get_valid_access_token", lambda: SOME_ACCESS_TOKEN)

    canned_response = MockResponse()
    monkeypatch.setattr(requests, "get", lambda *args, **kwargs: canned_response)

    job_data = {"job": SOME_JOB_ID, "status": "not completed!"}
    outcome = device._query_results(job_data)

    assert outcome == canned_response.mock_get_response
def test_get_valid_access_token_using_refresh_token(self, access_token_expiry, monkeypatch):
    """Test that ``get_valid_access_token`` returns a new access token by
    refreshing with a refresh token that has not expired.

    A falsy ``access_token_expiry`` leaves the device's access token as the
    constructor set it; otherwise an access token with that expiry is stored.
    """
    from datetime import timedelta

    dev = HQSDevice(3, machine=DUMMY_MACHINE, user_email=DUMMY_EMAIL)

    if access_token_expiry:
        # Set the token to an outdated token
        dev._access_token = jwt.encode({"exp": access_token_expiry}, "secret")

    # Set a refresh token with an expiry date in the future.
    # ``now.replace(now.year + 1)`` raises ValueError when ``now`` falls on
    # 29 February, so add a timedelta instead (assumes the module-level
    # ``now`` is a ``datetime``).
    refresh_expiry = now + timedelta(days=365)
    dev._refresh_token = jwt.encode({"exp": refresh_expiry}, "secret")

    mock_response = MockResponseWithTokens()
    monkeypatch.setattr(requests, "post", lambda *args, **kwargs: mock_response)
    # Stubbed out so the test never writes tokens anywhere.
    monkeypatch.setattr(dev, "save_tokens", lambda *args, **kwargs: None)

    assert dev.get_valid_access_token() == MOCK_ACCESS_TOKEN
def test_set_api_configs(self):
    """Check that ``set_api_configs`` rebuilds ``hostname`` from the base
    hostname and target path, leaving the user untouched."""
    device = HQSDevice(3, machine=DUMMY_MACHINE, user_email=DUMMY_EMAIL)

    new_user = "******"
    device._user = new_user

    # Point the device at a different server before re-applying the configs.
    device.BASE_HOSTNAME = "https://server.someaddress.com"
    device.TARGET_PATH = "some/path"
    device.set_api_configs()

    assert device.hostname == "https://server.someaddress.com/some/path"
    assert device._user == new_user
def test_submit_circuit_method(self, monkeypatch):
    """Check that ``_submit_circuit`` posts once to the device hostname with
    the expected JSON body and headers."""
    SOME_ACCESS_TOKEN = "XYZ789"

    dev = HQSDevice(3, machine=DUMMY_MACHINE, user_email=DUMMY_EMAIL)
    monkeypatch.setattr(dev, "get_valid_access_token", lambda: SOME_ACCESS_TOKEN)

    recorded_calls = []

    def record_post(hostname, body, headers):
        # Capture the arguments instead of sending a real request.
        recorded_calls.append((hostname, body, headers))

    monkeypatch.setattr(requests, "post", record_post)

    tape, tape_openqasm = get_example_tape_with_qasm()

    # Key order matters: the body is compared as a serialized JSON string.
    expected_body = {
        "machine": DUMMY_MACHINE,
        "language": dev.LANGUAGE,
        "count": dev.shots,
        "options": None,
        "program": tape_openqasm,
    }
    expected_header = {
        "Content-Type": "application/json",
        "Authorization": SOME_ACCESS_TOKEN,
    }

    dev._submit_circuit(tape)

    assert len(recorded_calls) == 1

    hostname, body, headers = recorded_calls[0]
    assert hostname == dev.hostname
    assert body == json.dumps(expected_body)
    assert headers == expected_header
def test_user_not_found_error(self, monkeypatch, tmpdir):
    """Check that ``_login`` raises ``ValueError`` when no HQS username is
    available from the environment or the configuration."""
    # Blank out the environment variables consulted for user and config.
    for variable in ("HQS_USER", "PENNYLANE_CONF"):
        monkeypatch.setenv(variable, "")

    empty_folder = tmpdir.join("folder_without_a_config_file")
    monkeypatch.setattr("os.curdir", empty_folder)

    # force loading of config
    monkeypatch.setattr("pennylane.default_config", qml.Configuration("config.toml"))

    with pytest.raises(ValueError, match="No username for HQS platform found"):
        HQSDevice(2, machine=DUMMY_MACHINE)._login()
def test_set_api_configs(self):
    """Check that ``set_api_configs`` rebuilds both ``header`` and
    ``hostname`` from the device's current API settings."""
    device = HQSDevice(3, machine=DUMMY_MACHINE, api_key=SOME_API_KEY)

    new_api_key = "XYZ789"
    device._api_key = new_api_key

    # Point the device at a different server before re-applying the configs.
    device.BASE_HOSTNAME = "https://server.someaddress.com"
    device.TARGET_PATH = "some/path"
    device.set_api_configs()

    expected_header = {
        "x-api-key": new_api_key,
        "User-Agent": "pennylane-honeywell_v{}".format(__version__),
    }
    assert device.header == expected_header
    assert device.hostname == "https://server.someaddress.com/some/path"
def test_invalid_op_exception(self):
    """Check that running a circuit containing an unsupported operation
    raises ``qml.DeviceError``."""
    device = HQSDevice(2, machine=DUMMY_MACHINE, user_email=DUMMY_EMAIL)

    class DummyOp(qml.operation.Operation):
        """Operation the device is expected to reject."""

        num_wires = 1
        num_params = 0
        par_domain = None

    @qml.qnode(device)
    def circuit():
        DummyOp(wires=[0])
        return qml.expval(qml.PauliZ(0))

    # The error message is expected to name the offending gate.
    with pytest.raises(qml.DeviceError, match="Gate DummyOp not supported"):
        circuit()
def test_reset(self):
    """Check that ``reset`` clears the stored results and samples but keeps
    the number of shots."""
    device = HQSDevice(3, shots=10, machine=DUMMY_MACHINE, user_email=DUMMY_EMAIL)

    # Populate the state that ``reset`` is expected to clear.
    device._results = ["00"] * 10
    device._samples = np.zeros((10, 3))
    device.shots = 11

    device.reset()

    assert device._results is None
    assert device._samples is None
    # should not be reset
    assert device.shots == 11
def test_invalid_op_exception(self):
    """Check that running a circuit containing ``QubitUnitary`` raises
    ``qml.DeviceError`` because the gate is unsupported."""
    device = HQSDevice(2, machine=DUMMY_MACHINE, api_key=SOME_API_KEY)

    # Fixed 2x2 complex matrix applied through ``QubitUnitary``.
    unitary = np.array(
        [
            [0.6569534 + 0.35252813j, 0.56815252 + 0.34833727j],
            [-0.56815252 - 0.34833727j, 0.61216718 + 0.42557631j],
        ]
    )

    @qml.qnode(device)
    def circuit():
        qml.QubitUnitary(unitary, wires=0)
        return qml.expval(qml.PauliZ(0))

    with pytest.raises(qml.DeviceError, match="Gate QubitUnitary not supported"):
        circuit()
def test_default_init(self, num_wires, shots, retry_delay):
    """Check the attributes of a freshly constructed, API-key based device."""
    device = HQSDevice(num_wires, DUMMY_MACHINE, shots, SOME_API_KEY, retry_delay)

    # Constructor arguments are exposed unchanged.
    assert device.num_wires == num_wires
    assert device.shots == shots
    assert device.retry_delay == retry_delay
    assert device.analytic == False

    # Request payload template built at construction time.
    expected_data = {
        "machine": DUMMY_MACHINE,
        "language": "OPENQASM 2.0",
        "priority": "normal",
        "count": shots,
        "options": None,
    }
    assert device.data == expected_data

    # Nothing has been executed yet.
    assert device._results is None
    assert device._samples is None

    assert device.BASE_HOSTNAME == BASE_HOSTNAME
    assert API_HEADER_KEY in device.header
    assert device.header[API_HEADER_KEY] == SOME_API_KEY
def test_default_init(self, num_wires, shots, retry_delay):
    """Check the attributes of a freshly constructed, user-email based device."""
    device = HQSDevice(
        num_wires,
        DUMMY_MACHINE,
        shots,
        user_email=DUMMY_EMAIL,
        retry_delay=retry_delay,
    )

    # Constructor arguments are exposed unchanged.
    assert device.num_wires == num_wires
    assert device.shots == shots
    assert device.retry_delay == retry_delay
    assert device.analytic == False

    # Request payload template built at construction time.
    expected_data = {
        "machine": DUMMY_MACHINE,
        "language": "OPENQASM 2.0",
        "count": shots,
        "options": None,
    }
    assert device.data == expected_data

    # Nothing has been executed yet.
    assert device._results is None
    assert device._samples is None

    assert device.BASE_HOSTNAME == BASE_HOSTNAME
    assert device._user == DUMMY_EMAIL
def test_save_tokens(self, monkeypatch, tmpdir, tokens, new_dir):
    """Check that ``save_tokens`` writes the tokens into the PennyLane
    configuration file."""
    mock_config = qml.Configuration("config.toml")

    # With ``new_dir`` set, the target directory doesn't exist yet.
    target_dir = tmpdir.join("new_dir") if new_dir else tmpdir
    filepath = target_dir.join("config.toml")
    mock_config._filepath = filepath

    monkeypatch.setattr(qml, "default_config", mock_config)
    HQSDevice(2, machine=DUMMY_MACHINE).save_tokens(*tokens)

    with open(filepath) as config_handle:
        configuration_file = toml.load(config_handle)

    saved = configuration_file["honeywell"]["global"]
    assert saved["access_token"] == tokens[0]

    # The refresh token is only checked when one was passed in.
    if len(tokens) > 1:
        assert saved["refresh_token"] == tokens[1]
def test_get_valid_access_token_using_refresh_token_raises(self, access_token_expiry, monkeypatch):
    """Test that ``get_valid_access_token`` raises ``RequestFailedError``
    when the request made with an unexpired refresh token is unsuccessful.

    A falsy ``access_token_expiry`` leaves the device's access token as the
    constructor set it; otherwise an access token with that expiry is stored.
    """
    from datetime import timedelta

    dev = HQSDevice(3, machine=DUMMY_MACHINE, user_email=DUMMY_EMAIL)

    if access_token_expiry:
        # Set the token to an outdated token
        dev._access_token = jwt.encode({"exp": access_token_expiry}, "secret")

    # Set a refresh token with an expiry date in the future.
    # ``now.replace(now.year + 1)`` raises ValueError when ``now`` falls on
    # 29 February, so add a timedelta instead (assumes the module-level
    # ``now`` is a ``datetime``).
    refresh_expiry = now + timedelta(days=365)
    dev._refresh_token = jwt.encode({"exp": refresh_expiry}, "secret")

    mock_response = MockResponseUnsuccessfulRequest()
    monkeypatch.setattr(requests, "post", lambda *args, **kwargs: mock_response)
    # Stubbed out so the test never writes tokens anywhere.
    monkeypatch.setattr(dev, "save_tokens", lambda *args, **kwargs: None)

    with pytest.raises(RequestFailedError, match="Failed to get access token"):
        dev.get_valid_access_token()
def test_token_is_expired_raises(self):
    """Check that ``token_is_expired`` raises ``InvalidJWTError`` when given
    something that is not a JWT."""
    device = HQSDevice(2, machine=DUMMY_MACHINE)

    # The ``Exception`` class itself is passed as the (invalid) token.
    with pytest.raises(InvalidJWTError, match="Invalid JWT token"):
        device.token_is_expired(Exception)