def __init__(self, **kwargs):
    """Assemble a test ``Config`` and store it as ``self.config``.

    ``DaskGateway`` options are seeded with local test defaults: three random
    ports on 127.0.0.1, an in-memory sqlite URL, ``DummyAuthenticator`` and
    ``InProcessClusterManager``. Keyword arguments override those defaults.
    An optional ``config`` keyword, when truthy, is merged on top afterwards.
    """
    extra = kwargs.pop("config", None)
    settings = dict(
        gateway_url="tls://127.0.0.1:%d" % random_port(),
        private_url="http://127.0.0.1:%d" % random_port(),
        public_url="http://127.0.0.1:%d" % random_port(),
        db_url="sqlite:///:memory:",
        authenticator_class="dask_gateway_server.auth.DummyAuthenticator",
        cluster_manager_class=(
            "dask_gateway_server.managers.inprocess.InProcessClusterManager"
        ),
    )
    settings.update(kwargs)

    config = Config()
    config["DaskGateway"].update(settings)
    if extra:
        config.merge(extra)
    self.config = config
async def test_jupyterhub_auth(tmpdir, monkeypatch):
    """JupyterHubAuthenticator rejects a bogus token and accepts a user's token."""
    from jupyterhub.tests.utils import add_user

    gateway_address = "http://127.0.0.1:%d" % random_port()
    jhub_api_token = uuid.uuid4().hex
    jhub_bind_url = "http://127.0.0.1:%i/@/space%%20word/" % random_port()

    # Register the gateway as a JupyterHub service.
    service_spec = {
        "name": "dask-gateway",
        "url": gateway_address,
        "api_token": jhub_api_token,
    }
    hub_config = Config()
    hub_config.JupyterHub.services = [service_spec]
    hub_config.JupyterHub.bind_url = jhub_bind_url

    class MockHub(hub_mocking.MockHub):
        def init_logging(self):
            pass

    hub = MockHub(config=hub_config)

    # Point the gateway at the hub's API for authentication.
    config = Config()
    config.DaskGateway.public_url = gateway_address + "/services/dask-gateway/"
    config.DaskGateway.temp_dir = str(tmpdir)
    config.DaskGateway.authenticator_class = (
        "dask_gateway_server.auth.JupyterHubAuthenticator"
    )
    config.JupyterHubAuthenticator.jupyterhub_api_token = jhub_api_token
    config.JupyterHubAuthenticator.jupyterhub_api_url = jhub_bind_url + "api/"

    async with temp_gateway(config=config) as gateway_proc, temp_hub(hub):
        # Create user "alice" and mint a real API token for her.
        alice = add_user(hub.db, name="alice")
        good_token = alice.new_api_token()
        hub.db.commit()

        # The client starts out with a random, invalid token.
        auth = JupyterHubAuth(api_token=uuid.uuid4().hex)
        async with Gateway(
            address=gateway_proc.public_urls.connect_url,
            proxy_address=gateway_proc.gateway_urls.connect_url,
            asynchronous=True,
            auth=auth,
        ) as gateway:
            with pytest.raises(Exception):
                await gateway.list_clusters()

            # Swapping in alice's token makes the same client succeed.
            auth.api_token = good_token
            await gateway.list_clusters()
def test_resume_clusters_forbid_in_memory_db(tmpdir):
    """Keeping clusters alive across restarts is rejected for an in-memory db.

    With ``stop_clusters_on_shutdown=False`` and ``db_url="sqlite://"``,
    constructing ``DaskGateway`` must raise a ``ValueError`` whose message
    mentions ``stop_clusters_on_shutdown``.
    """
    with pytest.raises(ValueError) as exc:
        settings = dict(
            gateway_url="tls://127.0.0.1:%d" % random_port(),
            private_url="http://127.0.0.1:%d" % random_port(),
            public_url="http://127.0.0.1:%d" % random_port(),
            temp_dir=str(tmpdir.join("dask-gateway")),
            db_url="sqlite://",
            stop_clusters_on_shutdown=False,
            authenticator_class="dask_gateway_server.auth.DummyAuthenticator",
        )
        DaskGateway(**settings)

    assert "stop_clusters_on_shutdown" in str(exc.value)
def two_proxies():
    """Yield a managed ``WebProxy`` together with an externally-managed twin.

    Both proxies share the same public URL, API URL and auth token; the
    second is created with ``externally_managed=True``.

    On teardown every proxy that was actually constructed is stopped. If
    constructing either proxy fails, the original error propagates instead
    of an ``UnboundLocalError`` from the cleanup code. If stopping the first
    proxy fails, the second is still stopped.
    """
    kwargs = {
        "public_url": "http://127.0.0.1:%s" % random_port(),
        "api_url": "http://127.0.0.1:%s" % random_port(),
        "auth_token": "abcdefg",
    }
    # Pre-bind so ``finally`` can tell which proxies were actually created.
    proxy = proxy2 = None
    try:
        proxy = WebProxy(**kwargs)
        proxy2 = WebProxy(externally_managed=True, **kwargs)
        yield proxy, proxy2
    finally:
        try:
            if proxy is not None:
                proxy.stop()
        finally:
            # Runs even if stopping the first proxy raised.
            if proxy2 is not None:
                proxy2.stop()
def test_db_encrypt_keys_required(tmpdir):
    """A file-backed database without encryption keys is rejected.

    Building and initializing a ``DaskGateway`` against an on-disk sqlite db
    must raise a ``ValueError`` that mentions ``DASK_GATEWAY_ENCRYPT_KEYS``.
    """
    with pytest.raises(ValueError) as exc:
        sqlite_file = tmpdir.join("dask_gateway.sqlite")
        gateway = DaskGateway(
            gateway_url="tls://127.0.0.1:%d" % random_port(),
            private_url="http://127.0.0.1:%d" % random_port(),
            public_url="http://127.0.0.1:%d" % random_port(),
            temp_dir=str(tmpdir.join("dask-gateway")),
            db_url="sqlite:///%s" % sqlite_file,
            authenticator_class="dask_gateway_server.auth.DummyAuthenticator",
        )
        gateway.initialize([])

    message = str(exc.value)
    assert "DASK_GATEWAY_ENCRYPT_KEYS" in message
async def test_shutdown_on_startup_error(tmpdir):
    """A failure during startup makes ``start_or_exit`` exit with status 1."""
    work_dir = str(tmpdir.join("dask-gateway"))
    cert_path = str(tmpdir.join("tls_cert.pem"))

    # A configuration that will cause a failure at runtime (not init time).
    # NOTE(review): presumably this is the ``tls_cert`` path, which nothing
    # here ever writes -- confirm.
    gateway = DaskGateway(
        gateway_url="tls://127.0.0.1:%d" % random_port(),
        private_url="http://127.0.0.1:%d" % random_port(),
        public_url="http://127.0.0.1:%d" % random_port(),
        temp_dir=work_dir,
        tls_cert=cert_path,
        authenticator_class="dask_gateway_server.auth.DummyAuthenticator",
    )

    with pytest.raises(SystemExit) as exc:
        gateway.initialize([])
        await gateway.start_or_exit()

    exit_code = exc.value.code
    assert exit_code == 1
async def scheduler_proxy():
    """Yield a started ``SchedulerProxy`` on a random local TLS port.

    ``stop()`` is always called afterwards, including when ``start()`` fails.
    """
    public_url = "tls://127.0.0.1:%s" % random_port()
    proxy = SchedulerProxy(public_url=public_url)
    try:
        await proxy.start()
        yield proxy
    finally:
        proxy.stop()
async def inner(self, tmpdir):
    """Run the wrapped ``func`` against a mock gateway API and a fresh manager.

    A tornado app exposing the cluster worker/registration handlers is served
    on a random port of this host. ``func(self, gateway, manager)`` runs
    against it. Afterwards the server is stopped, every cluster still
    registered on the gateway is cleaned up, and the manager's task pool is
    closed. Only if ``func`` itself succeeded does the test then fail when
    clusters remain.
    """
    port = random_port()
    host = get_ip()
    api_url = "http://%s:%d" % (host, port)

    gateway = MockGateway()
    handlers = [
        (
            "/clusters/([a-zA-Z0-9-_.]*)/workers/([a-zA-Z0-9-_.]*)",
            ClusterWorkersHandler,
        ),
        ("/clusters/([a-zA-Z0-9-_.]*)/addresses", ClusterRegistrationHandler),
    ]
    app = web.Application(handlers, gateway=gateway)
    manager = self.new_manager(api_url=api_url, temp_dir=str(tmpdir))

    server = None
    try:
        server = app.listen(port, address=host)
        await func(self, gateway, manager)
    finally:
        if server is not None:
            server.stop()
        for cluster in gateway.clusters.values():
            states = [worker.state for worker in cluster.workers.values()]
            await self.cleanup_cluster(manager, cluster.info, cluster.state, states)
        await manager.task_pool.close()

    # Only raise if test didn't fail earlier (this line is skipped when
    # ``func`` raised, so its error is not masked).
    assert not gateway.clusters, "Clusters %r not fully cleaned up" % list(
        gateway.clusters)
def __init__(self, routes=None, app=None, host="localhost", port=None):
    """Prepare an aiohttp app and runner for a local test server.

    Parameters
    ----------
    routes : optional
        Route definitions passed to ``app.add_routes`` when not ``None``.
    app : aiohttp ``web.Application``, optional
        Application to serve. A new one is created only when ``None``.
    host : str
        Host name stored for binding; defaults to ``"localhost"``.
    port : int, optional
        Port to bind. Any falsy value (``None`` or ``0``) picks a random port.
    """
    # ``web.Application`` is a MutableMapping, so an app that holds no state
    # yet is falsy. Test against ``None`` so a caller-supplied app (with its
    # routes) is never silently replaced by a fresh one.
    self.app = app if app is not None else web.Application()
    if routes is not None:
        self.app.add_routes(routes)
    self.runner = web.AppRunner(self.app)
    self.host = host
    # Existing semantics kept: 0 also means "choose a random port", so
    # ``self.port`` always holds a concrete port number.
    self.port = port or random_port()
async def test_kerberos_auth(tmpdir):
    """KerberosAuthenticator rejects clients without a ticket and accepts them with one.

    The ticket cache is cleared before the unauthenticated request. A ticket
    is then obtained with ``kinit()``, and ``kdestroy()`` is guaranteed to
    run afterwards so a failing assertion does not leave credentials behind
    for later tests.
    """
    config = Config()
    config.DaskGateway.public_url = "http://master.example.com:%d" % random_port()
    config.DaskGateway.temp_dir = str(tmpdir.join("dask-gateway"))
    config.DaskGateway.authenticator_class = (
        "dask_gateway_server.auth.KerberosAuthenticator"
    )
    config.KerberosAuthenticator.keytab = KEYTAB_PATH

    async with temp_gateway(config=config) as gateway_proc:
        async with Gateway(
            address=gateway_proc.public_url,
            proxy_address=gateway_proc.gateway_url,
            asynchronous=True,
            auth="kerberos",
        ) as gateway:
            # Without a ticket, the request must be refused.
            kdestroy()
            with pytest.raises(Exception):
                await gateway.list_clusters()

            # With a ticket, the request succeeds. Clean up the ticket even
            # if the request fails.
            kinit()
            try:
                await gateway.list_clusters()
            finally:
                kdestroy()
async def web_proxy():
    """Yield a started ``WebProxy`` on a random local HTTP port.

    ``stop()`` is always called afterwards, including when ``start()`` fails.
    """
    public_url = "http://127.0.0.1:%s" % random_port()
    proxy = WebProxy(public_url=public_url)
    try:
        await proxy.start()
        yield proxy
    finally:
        proxy.stop()
def hello_server():
    """Yield the base URL of a local server that routes ``/`` to ``HelloHandler``.

    The server is stopped on teardown. ``listen`` is called before the
    ``try`` block, so a bind failure surfaces as itself rather than as an
    ``UnboundLocalError`` raised from the ``finally`` clause.
    """
    port = random_port()
    app = web.Application([(r"/", HelloHandler)])
    server = app.listen(port)
    try:
        yield "http://127.0.0.1:%d" % port
    finally:
        server.stop()
def __init__(self, **kwargs):
    """Create a ``Proxy`` under test plus an aiohttp app/runner.

    The proxy listens on an OS-assigned port (``127.0.0.1:0``) under the
    ``/foobar`` prefix. It forwards gateway traffic to a random local port
    kept in ``self._port``. Extra keyword arguments go straight to ``Proxy``.
    """
    self._port = random_port()
    gateway_address = f"127.0.0.1:{self._port}"
    self.proxy = Proxy(
        address="127.0.0.1:0",
        prefix="/foobar",
        gateway_address=gateway_address,
        log_level="debug",
        proxy_status_period=0.5,
        **kwargs,
    )
    app = web.Application()
    self.app = app
    self.runner = web.AppRunner(app)
async def test_web_proxy_bad_target(proxy):
    """A PATH route whose target has no listener answers with HTTP 502."""
    async with ClientSession() as client:
        # Add a bad route: a random port that nothing is serving.
        dead_target = "http://127.0.0.1:%d" % random_port()
        await proxy.add_route(kind="PATH", path="/hello", target=dead_target)
        url = f"http://{proxy.address}{proxy.prefix}/hello"

        # Route not available.
        async def expect_bad_gateway():
            response = await client.get(url)
            assert response.status == 502

        await with_retries(expect_bad_gateway, 5)
async def inner(self, tmpdir):
    """Run the wrapped ``func`` against a ``MockGateway`` and clean up after it.

    The mock gateway is bound to this host's IP on a random port with
    ``tmpdir`` as its directory. After ``func(self, gateway)`` returns or
    raises, every cluster still registered on the gateway is cleaned up
    through its own manager. Only if ``func`` succeeded does the test then
    fail when clusters remain.
    """
    port = random_port()
    host = get_ip()
    async with MockGateway(host, port, str(tmpdir)) as gateway:
        try:
            await func(self, gateway)
        finally:
            for cluster in gateway.clusters.values():
                states = [worker.state for worker in cluster.workers.values()]
                await self.cleanup_cluster(cluster.manager, cluster.state, states)

        # Only raise if test didn't fail earlier (skipped when ``func``
        # raised, so its error is not masked).
        assert not gateway.clusters, "Clusters %r not fully cleaned up" % list(
            gateway.clusters)
async def test_web_proxy_bad_target(web_proxy):
    """A route to an address with no listener is stored but answers with 502."""
    assert not await web_proxy.get_all_routes()

    target = "http://127.0.0.1:%d" % random_port()
    url = web_proxy.public_url + "/hello"

    await web_proxy.add_route("/hello", target)
    assert await web_proxy.get_all_routes() == {"/hello": target}

    # Route not available: the proxy reports a bad gateway.
    response = await AsyncHTTPClient().fetch(HTTPRequest(url=url), raise_error=False)
    assert response.code == 502
async def ca_and_tls_web_proxy(tmpdir_factory):
    """Yield ``(ca, proxy)``: a trustme CA and a TLS ``WebProxy`` it signed for.

    The test is skipped when ``trustme`` is unavailable. A certificate for
    ``127.0.0.1`` is written into a fresh ``certs`` temp dir. The proxy
    serves HTTPS on a random port and is stopped afterwards.
    """
    trustme = pytest.importorskip("trustme")
    ca = trustme.CA()
    server_cert = ca.issue_cert("127.0.0.1")

    certdir = tmpdir_factory.mktemp("certs")
    key_path = str(certdir.join("key.pem"))
    cert_path = str(certdir.join("cert.pem"))
    server_cert.private_key_pem.write_to_path(key_path)
    server_cert.cert_chain_pems[0].write_to_path(cert_path)

    proxy = WebProxy(
        public_url="https://127.0.0.1:%s" % random_port(),
        tls_key=key_path,
        tls_cert=cert_path,
    )
    try:
        await proxy.start()
        yield ca, proxy
    finally:
        proxy.stop()
async def test_jupyterhub_auth_user(monkeypatch):
    """A JupyterHub user's token is accepted; a random token is rejected."""
    from jupyterhub.tests.utils import add_user

    jhub_api_token = uuid.uuid4().hex
    jhub_bind_url = "http://127.0.0.1:%i/@/space%%20word/" % random_port()

    hub_config = Config()
    hub_config.JupyterHub.services = [
        {"name": "dask-gateway", "api_token": jhub_api_token},
    ]
    hub_config.JupyterHub.bind_url = jhub_bind_url

    class MockHub(hub_mocking.MockHub):
        def init_logging(self):
            pass

    hub = MockHub(config=hub_config)
    gateway_config = configure_dask_gateway(jhub_api_token, jhub_bind_url)

    async with temp_gateway(config=gateway_config) as g, temp_hub(hub):
        # Create user "alice" and mint a real API token for her.
        alice = add_user(hub.db, name="alice")
        good_token = alice.new_api_token()
        hub.db.commit()

        # The client starts out with a random, invalid token.
        auth = JupyterHubAuth(api_token=uuid.uuid4().hex)
        async with g.gateway_client(auth=auth) as gateway:
            with pytest.raises(Exception):
                await gateway.list_clusters()

            # Swapping in alice's token makes the same client succeed.
            auth.api_token = good_token
            await gateway.list_clusters()
async def test_jupyterhub_auth_service(monkeypatch):
    """A JupyterHub *service* token is accepted; a random token is rejected.

    Two services are registered with the hub: the gateway itself and an
    unrelated "any-service". The successful request authenticates with the
    unrelated service's token. Previously that token was created but never
    used, and the check silently reused the gateway's own token.
    """
    jhub_api_token = uuid.uuid4().hex
    jhub_service_token = uuid.uuid4().hex
    jhub_bind_url = "http://127.0.0.1:%i/@/space%%20word/" % random_port()

    hub_config = Config()
    hub_config.JupyterHub.services = [
        {"name": "dask-gateway", "api_token": jhub_api_token},
        {"name": "any-service", "api_token": jhub_service_token},
    ]
    hub_config.JupyterHub.bind_url = jhub_bind_url

    class MockHub(hub_mocking.MockHub):
        def init_logging(self):
            pass

    hub = MockHub(config=hub_config)

    # Configure gateway (it talks to the hub with its own service token).
    config = configure_dask_gateway(jhub_api_token, jhub_bind_url)

    async with temp_gateway(config=config) as g:
        async with temp_hub(hub):
            # Configure auth with incorrect api token
            auth = JupyterHubAuth(api_token=uuid.uuid4().hex)
            async with g.gateway_client(auth=auth) as gateway:
                # Auth fails with bad token
                with pytest.raises(Exception):
                    await gateway.list_clusters()

                # Auth works with another service's token.
                # NOTE(review): relies on the hub letting the gateway resolve
                # tokens owned by other services -- confirm under the pinned
                # JupyterHub version's scope rules.
                auth.api_token = jhub_service_token
                await gateway.list_clusters()
def __init__(self):
    """Prepare a web app serving ``SlowHandler`` at ``/`` on a random local port.

    ``self.tasks`` is a set handed to the app as its ``task_set`` setting.
    ``self.address`` is the server's base URL.
    """
    port = random_port()
    self.port = port
    self.address = "http://127.0.0.1:%d" % port
    self.tasks = set()
    routes = [(r"/", SlowHandler)]
    self.app = web.Application(routes, task_set=self.tasks)
async def test_gateway_resume_clusters_after_shutdown(tmpdir):
    """Clusters left running across a gateway restart are reconciled.

    The first gateway runs with ``stop_clusters_on_shutdown=False`` against a
    persistent, encrypted sqlite db. Afterwards one worker of cluster 1 and
    the process of cluster 2 are killed with SIGTERM. A second gateway on the
    same db must end up with cluster 1 as the only active cluster, with two
    active workers. It must still be able to run a task on cluster 1.
    """
    db_url = "sqlite:///%s" % tmpdir.join("dask_gateway.sqlite")
    db_encrypt_keys = [Fernet.generate_key()]

    config = Config()
    config.DaskGateway.backend_class = LocalTestingBackend
    config.LocalTestingBackend.db_url = db_url
    config.LocalTestingBackend.db_encrypt_keys = db_encrypt_keys
    config.LocalTestingBackend.stop_clusters_on_shutdown = False
    config.LocalTestingBackend.cluster_heartbeat_period = 1
    config.LocalTestingBackend.check_timeouts_period = 0.5
    config.DaskGateway.address = "127.0.0.1:%d" % random_port()
    config.Proxy.address = "127.0.0.1:%d" % random_port()

    # --- First gateway: start clusters, then shut down without stopping them.
    async with temp_gateway(config=config) as g:
        async with g.gateway_client() as gateway:
            first_name = await gateway.submit()
            async with gateway.connect(first_name) as c:
                await c.scale(2)
            second_name = await gateway.submit()
            async with gateway.connect(second_name):
                pass
            async with gateway.new_cluster():
                pass

        active_by_name = {
            record.name: record for record in g.gateway.backend.db.active_clusters()
        }

    # Active clusters are not stopped on shutdown.
    assert active_by_name

    # Stop 1 worker in cluster 1.
    victim = list(active_by_name[first_name].workers.values())[0]
    os.kill(victim.state["pid"], signal.SIGTERM)

    # Stop cluster 2.
    os.kill(active_by_name[second_name].state["pid"], signal.SIGTERM)

    # --- Second gateway: it must notice the kills and resume cluster 1.
    config.LocalTestingBackend.stop_clusters_on_shutdown = True
    async with temp_gateway(config=config) as g:
        backend = g.gateway.backend

        def check_reconciled():
            clusters = list(backend.db.active_clusters())
            assert len(clusters) == 1
            survivor = clusters[0]
            assert survivor.name == first_name
            assert len(survivor.workers) >= 3
            assert len(survivor.active_workers()) == 2

        # Poll up to 10 times, 0.5s apart; re-raise the last failure.
        for attempt in range(10):
            try:
                check_reconciled()
            except AssertionError:
                if attempt == 9:
                    raise
                await asyncio.sleep(0.5)
            else:
                break

        # Check that cluster is available and everything still works.
        async with g.gateway_client() as gateway:
            async with gateway.connect(first_name, shutdown_on_close=True) as cluster:
                async with cluster.get_client(set_as_default=False) as client:
                    res = await client.submit(lambda x: x + 1, 1)
                    assert res == 2