Example #1
0
    def __init__(self, **kwargs):
        """Assemble a ``Config`` for a throwaway ``DaskGateway``.

        Keyword arguments (other than ``config``) override the
        ``DaskGateway`` defaults defined below. A ``config`` keyword, when
        given and truthy, is merged on top of the result last, so its values
        win over everything else.
        """
        config = Config()
        extra_config = kwargs.pop("config", None)

        # Defaults: a fresh random local port per endpoint, an in-memory
        # SQLite database, the dummy authenticator and the in-process
        # cluster manager.
        gateway_options = dict(
            gateway_url="tls://127.0.0.1:%d" % random_port(),
            private_url="http://127.0.0.1:%d" % random_port(),
            public_url="http://127.0.0.1:%d" % random_port(),
            db_url="sqlite:///:memory:",
            authenticator_class="dask_gateway_server.auth.DummyAuthenticator",
            cluster_manager_class=(
                "dask_gateway_server.managers.inprocess.InProcessClusterManager"
            ),
        )
        # Caller-supplied keywords take precedence over the defaults.
        gateway_options.update(kwargs)
        config["DaskGateway"].update(gateway_options)

        if extra_config:
            config.merge(extra_config)

        self.config = config
Example #2
0
async def test_jupyterhub_auth(tmpdir, monkeypatch):
    """Check JupyterHub-backed authentication against a mock hub.

    A mock JupyterHub registers the gateway as the ``dask-gateway`` service,
    and a gateway is started with ``JupyterHubAuthenticator`` pointed at that
    hub's API. A random API token must be rejected, and a token issued for
    the hub user ``alice`` must be accepted.
    """
    from jupyterhub.tests.utils import add_user

    gateway_address = "http://127.0.0.1:%d" % random_port()
    jhub_api_token = uuid.uuid4().hex
    # ``%%20`` becomes a literal ``%20`` after %-formatting, so the bind URL
    # contains an ``@`` segment and an escaped space -- presumably to
    # exercise URL handling in the authenticator (TODO confirm intent).
    jhub_bind_url = "http://127.0.0.1:%i/@/space%%20word/" % random_port()

    # Register the gateway as a hub service that shares ``jhub_api_token``.
    hub_config = Config()
    hub_config.JupyterHub.services = [{
        "name": "dask-gateway",
        "url": gateway_address,
        "api_token": jhub_api_token
    }]
    hub_config.JupyterHub.bind_url = jhub_bind_url

    class MockHub(hub_mocking.MockHub):
        def init_logging(self):
            # No-op override: skip the hub's own logging setup.
            pass

    hub = MockHub(config=hub_config)

    # Configure gateway
    config = Config()
    config.DaskGateway.public_url = gateway_address + "/services/dask-gateway/"
    config.DaskGateway.temp_dir = str(tmpdir)
    config.DaskGateway.authenticator_class = (
        "dask_gateway_server.auth.JupyterHubAuthenticator")
    config.JupyterHubAuthenticator.jupyterhub_api_token = jhub_api_token
    config.JupyterHubAuthenticator.jupyterhub_api_url = jhub_bind_url + "api/"

    async with temp_gateway(config=config) as gateway_proc:
        async with temp_hub(hub):
            # Create a new jupyterhub user alice, and get the api token
            u = add_user(hub.db, name="alice")
            api_token = u.new_api_token()
            # Commit so the new token is visible when the hub validates it.
            hub.db.commit()

            # Configure auth with incorrect api token
            auth = JupyterHubAuth(api_token=uuid.uuid4().hex)

            async with Gateway(
                    address=gateway_proc.public_urls.connect_url,
                    proxy_address=gateway_proc.gateway_urls.connect_url,
                    asynchronous=True,
                    auth=auth,
            ) as gateway:

                # Auth fails with bad token
                with pytest.raises(Exception):
                    await gateway.list_clusters()

                # Auth works with correct token (swapped in on the same
                # ``auth`` object the client already holds).
                auth.api_token = api_token
                await gateway.list_clusters()
Example #3
0
def test_resume_clusters_forbid_in_memory_db(tmpdir):
    """Reject ``stop_clusters_on_shutdown=False`` with an in-memory database.

    Constructing the gateway with ``db_url="sqlite://"`` and
    ``stop_clusters_on_shutdown=False`` must raise a ``ValueError`` whose
    message names ``stop_clusters_on_shutdown``.
    """
    with pytest.raises(ValueError) as excinfo:
        gateway_kwargs = {
            "gateway_url": "tls://127.0.0.1:%d" % random_port(),
            "private_url": "http://127.0.0.1:%d" % random_port(),
            "public_url": "http://127.0.0.1:%d" % random_port(),
            "temp_dir": str(tmpdir.join("dask-gateway")),
            "db_url": "sqlite://",
            "stop_clusters_on_shutdown": False,
            "authenticator_class":
                "dask_gateway_server.auth.DummyAuthenticator",
        }
        DaskGateway(**gateway_kwargs)

    message = str(excinfo.value)
    assert "stop_clusters_on_shutdown" in message
Example #4
0
def two_proxies():
    """Yield a ``(proxy, proxy2)`` pair of ``WebProxy`` objects.

    Both share the same public URL, API URL and auth token. The second is
    created with ``externally_managed=True``. On teardown, every proxy that
    was successfully constructed is stopped, even if construction of the
    other failed or stopping the first one raised.
    """
    kwargs = {
        "public_url": "http://127.0.0.1:%s" % random_port(),
        "api_url": "http://127.0.0.1:%s" % random_port(),
        "auth_token": "abcdefg",
    }
    # Bind both names up front. Otherwise a failure in either constructor
    # would make ``finally`` raise UnboundLocalError and hide the real error.
    proxy = proxy2 = None
    try:
        proxy = WebProxy(**kwargs)
        proxy2 = WebProxy(externally_managed=True, **kwargs)
        yield proxy, proxy2
    finally:
        try:
            if proxy is not None:
                proxy.stop()
        finally:
            # Still stop the second proxy if stopping the first one raised.
            if proxy2 is not None:
                proxy2.stop()
Example #5
0
def test_db_encrypt_keys_required(tmpdir):
    """A file-backed SQLite database must make ``initialize`` fail here.

    No encryption keys are configured in this test. Initialization is
    expected to raise a ``ValueError`` whose message mentions
    ``DASK_GATEWAY_ENCRYPT_KEYS``.
    """
    with pytest.raises(ValueError) as excinfo:
        app = DaskGateway(
            gateway_url="tls://127.0.0.1:%d" % random_port(),
            private_url="http://127.0.0.1:%d" % random_port(),
            public_url="http://127.0.0.1:%d" % random_port(),
            temp_dir=str(tmpdir.join("dask-gateway")),
            db_url="sqlite:///%s" % tmpdir.join("dask_gateway.sqlite"),
            authenticator_class="dask_gateway_server.auth.DummyAuthenticator",
        )
        app.initialize([])

    error_text = str(excinfo.value)
    assert "DASK_GATEWAY_ENCRYPT_KEYS" in error_text
Example #6
0
async def test_shutdown_on_startup_error(tmpdir):
    """A gateway that fails only once running must exit with status 1.

    ``tls_cert`` points at a path inside ``tmpdir`` that this test never
    creates, so construction succeeds but startup is expected to fail.
    """
    missing_cert = str(tmpdir.join("tls_cert.pem"))
    gateway = DaskGateway(
        gateway_url="tls://127.0.0.1:%d" % random_port(),
        private_url="http://127.0.0.1:%d" % random_port(),
        public_url="http://127.0.0.1:%d" % random_port(),
        temp_dir=str(tmpdir.join("dask-gateway")),
        tls_cert=missing_cert,
        authenticator_class="dask_gateway_server.auth.DummyAuthenticator",
    )

    with pytest.raises(SystemExit) as excinfo:
        gateway.initialize([])
        await gateway.start_or_exit()

    exit_code = excinfo.value.code
    assert exit_code == 1
Example #7
0
async def scheduler_proxy():
    """Yield a started ``SchedulerProxy`` bound to a random local TLS URL.

    ``stop()`` runs on teardown, including when ``start()`` itself raises.
    """
    listen_url = "tls://127.0.0.1:%s" % random_port()
    sched_proxy = SchedulerProxy(public_url=listen_url)
    try:
        await sched_proxy.start()
        yield sched_proxy
    finally:
        sched_proxy.stop()
Example #8
0
    async def inner(self, tmpdir):
        """Run the wrapped test ``func`` against a mock gateway HTTP API.

        Serves the cluster worker and registration handlers on a random port
        of this host, builds a manager pointed at that API and awaits
        ``func(self, gateway, manager)``. Afterwards every cluster still
        tracked by the mock gateway is cleaned up, the manager's task pool is
        closed, and the test fails if any clusters remain.
        """
        port = random_port()
        host = get_ip()
        api_url = "http://%s:%d" % (host, port)
        gateway = MockGateway()
        app = web.Application(
            [
                (
                    "/clusters/([a-zA-Z0-9-_.]*)/workers/([a-zA-Z0-9-_.]*)",
                    ClusterWorkersHandler,
                ),
                ("/clusters/([a-zA-Z0-9-_.]*)/addresses",
                 ClusterRegistrationHandler),
            ],
            gateway=gateway,
        )
        manager = self.new_manager(api_url=api_url, temp_dir=str(tmpdir))
        server = None
        try:
            server = app.listen(port, address=host)
            await func(self, gateway, manager)
        finally:
            if server is not None:
                server.stop()
            # Iterate over a snapshot. ``cleanup_cluster`` awaits, and if
            # anything drops entries from ``gateway.clusters`` in the
            # meantime, iterating the live view would raise RuntimeError here
            # and mask the test's own failure.
            for cluster in list(gateway.clusters.values()):
                worker_states = [w.state for w in cluster.workers.values()]
                await self.cleanup_cluster(manager, cluster.info,
                                           cluster.state, worker_states)
            await manager.task_pool.close()

        # Only raise if test didn't fail earlier
        if gateway.clusters:
            assert False, "Clusters %r not fully cleaned up" % list(
                gateway.clusters)
Example #9
0
 def __init__(self, routes=None, app=None, host="localhost", port=None):
     """Prepare an aiohttp application and runner for ``host:port``.

     :param routes: optional route definitions added via ``add_routes``.
     :param app: an existing ``web.Application`` to use; a new one is
         created when this is ``None``.
     :param host: stored as ``self.host``.
     :param port: stored as ``self.port``; a falsy value (``None`` or ``0``)
         is replaced by ``random_port()``.
     """
     # Compare with ``None`` explicitly. ``web.Application`` is dict-like, so
     # relying on its truthiness could silently discard a caller-supplied app.
     self.app = app if app is not None else web.Application()
     if routes is not None:
         self.app.add_routes(routes)
     self.runner = web.AppRunner(self.app)
     self.host = host
     self.port = port or random_port()
Example #10
0
async def test_kerberos_auth(tmpdir):
    """Check Kerberos authentication against a keytab-configured gateway.

    Requests must fail after ``kdestroy()`` and succeed after ``kinit()``.
    ``kdestroy()`` is always run after the authenticated request, even if
    that request fails, so credentials set up by ``kinit()`` are not left
    behind for later tests.
    """
    config = Config()
    config.DaskGateway.public_url = (
        "http://master.example.com:%d" % random_port())
    config.DaskGateway.temp_dir = str(tmpdir.join("dask-gateway"))
    config.DaskGateway.authenticator_class = (
        "dask_gateway_server.auth.KerberosAuthenticator")
    config.KerberosAuthenticator.keytab = KEYTAB_PATH

    async with temp_gateway(config=config) as gateway_proc:
        async with Gateway(
                address=gateway_proc.public_url,
                proxy_address=gateway_proc.gateway_url,
                asynchronous=True,
                auth="kerberos",
        ) as gateway:

            # Without credentials the request must be rejected.
            kdestroy()

            with pytest.raises(Exception):
                await gateway.list_clusters()

            # With credentials it must succeed; clean up regardless.
            kinit()
            try:
                await gateway.list_clusters()
            finally:
                kdestroy()
Example #11
0
async def web_proxy():
    """Yield a started ``WebProxy`` on a random local HTTP URL.

    ``stop()`` runs on teardown, including when ``start()`` itself raises.
    """
    listen_url = "http://127.0.0.1:%s" % random_port()
    http_proxy = WebProxy(public_url=listen_url)
    try:
        await http_proxy.start()
        yield http_proxy
    finally:
        http_proxy.stop()
Example #12
0
def hello_server():
    """Serve ``HelloHandler`` at ``/`` on a random port and yield its URL.

    The server is stopped on teardown.
    """
    port = random_port()
    app = web.Application([(r"/", HelloHandler)])
    # Listen before entering ``try``. If ``listen`` raises there is nothing
    # to stop, and touching an unbound ``server`` in ``finally`` would raise
    # UnboundLocalError and hide the real error.
    server = app.listen(port)
    try:
        yield "http://127.0.0.1:%d" % port
    finally:
        server.stop()
Example #13
0
 def __init__(self, **kwargs):
     """Create the ``Proxy`` under test plus an aiohttp app and runner.

     The proxy gets address ``127.0.0.1:0``, prefix ``/foobar``, debug
     logging and a 0.5 status period. Its gateway address uses a random
     port, which is kept in ``self._port``. Extra keyword arguments are
     forwarded to ``Proxy``.
     """
     self._port = random_port()
     upstream = f"127.0.0.1:{self._port}"
     self.proxy = Proxy(
         address="127.0.0.1:0",
         prefix="/foobar",
         gateway_address=upstream,
         log_level="debug",
         proxy_status_period=0.5,
         **kwargs,
     )
     self.app = web.Application()
     self.runner = web.AppRunner(self.app)
Example #14
0
async def test_web_proxy_bad_target(proxy):
    """A route whose target this test does not serve must yield a 502.

    The check is retried via ``with_retries`` (up to 5 attempts).
    """
    async with ClientSession() as client:
        # Add a bad route
        addr = "http://127.0.0.1:%d" % random_port()
        await proxy.add_route(kind="PATH", path="/hello", target=addr)

        proxied_addr = f"http://{proxy.address}{proxy.prefix}/hello"

        # Route not available
        async def test_502():
            # Use the response as a context manager so its connection is
            # released on every attempt, not held until the session closes.
            async with client.get(proxied_addr) as resp:
                assert resp.status == 502

        await with_retries(test_502, 5)
Example #15
0
    async def inner(self, tmpdir):
        """Run the wrapped test ``func`` against a ``MockGateway``.

        The mock gateway is created with this host's IP (``get_ip()``), a
        random port and ``str(tmpdir)``. After ``func(self, gateway)`` runs,
        every cluster still tracked by the gateway is cleaned up, and the
        test fails if any remain.
        """
        port = random_port()
        host = get_ip()
        async with MockGateway(host, port, str(tmpdir)) as gateway:
            try:
                await func(self, gateway)
            finally:
                # Iterate over a snapshot. ``cleanup_cluster`` awaits, and if
                # anything drops entries from ``gateway.clusters`` in the
                # meantime, iterating the live view would raise RuntimeError
                # and mask the test's own failure.
                for cluster in list(gateway.clusters.values()):
                    worker_states = [w.state for w in cluster.workers.values()]
                    await self.cleanup_cluster(cluster.manager, cluster.state,
                                               worker_states)

        # Only raise if test didn't fail earlier
        if gateway.clusters:
            assert False, "Clusters %r not fully cleaned up" % list(
                gateway.clusters)
Example #16
0
async def test_web_proxy_bad_target(web_proxy):
    """A route to a port this test does not serve must be proxied as 502."""
    # The proxy starts with an empty routing table.
    assert not await web_proxy.get_all_routes()

    http = AsyncHTTPClient()

    target = "http://127.0.0.1:%d" % random_port()
    url = web_proxy.public_url + "/hello"

    await web_proxy.add_route("/hello", target)
    registered = await web_proxy.get_all_routes()
    assert registered == {"/hello": target}

    # Route not available: nothing is started on ``target`` here.
    response = await http.fetch(HTTPRequest(url=url), raise_error=False)
    assert response.code == 502
Example #17
0
async def ca_and_tls_web_proxy(tmpdir_factory):
    """Yield ``(ca, proxy)``: a trustme CA and a started HTTPS ``WebProxy``.

    The test is skipped when ``trustme`` is not importable. A certificate for
    ``127.0.0.1`` is issued by the CA, and its key and first chain PEM are
    written to a fresh ``certs`` temp directory. The proxy serves HTTPS with
    them on a random port and is stopped on teardown.
    """
    trustme = pytest.importorskip("trustme")
    ca = trustme.CA()
    server_cert = ca.issue_cert("127.0.0.1")

    cert_dir = tmpdir_factory.mktemp("certs")
    key_path = str(cert_dir.join("key.pem"))
    cert_path = str(cert_dir.join("cert.pem"))

    server_cert.private_key_pem.write_to_path(key_path)
    server_cert.cert_chain_pems[0].write_to_path(cert_path)

    tls_proxy = WebProxy(
        public_url="https://127.0.0.1:%s" % random_port(),
        tls_key=key_path,
        tls_cert=cert_path,
    )
    try:
        await tls_proxy.start()
        yield ca, tls_proxy
    finally:
        tls_proxy.stop()
Example #18
0
async def test_jupyterhub_auth_user(monkeypatch):
    """Check that a hub *user's* API token authenticates with the gateway.

    A mock JupyterHub registers the ``dask-gateway`` service, and the
    gateway is configured via ``configure_dask_gateway`` with that service's
    token and the hub's bind URL. A random token must be rejected, and a
    token issued for hub user ``alice`` must be accepted.
    """
    from jupyterhub.tests.utils import add_user

    jhub_api_token = uuid.uuid4().hex
    # ``%%20`` becomes a literal ``%20`` after %-formatting, so the bind URL
    # contains an ``@`` segment and an escaped space -- presumably to
    # exercise URL handling (TODO confirm intent).
    jhub_bind_url = "http://127.0.0.1:%i/@/space%%20word/" % random_port()

    hub_config = Config()
    hub_config.JupyterHub.services = [{
        "name": "dask-gateway",
        "api_token": jhub_api_token
    }]
    hub_config.JupyterHub.bind_url = jhub_bind_url

    class MockHub(hub_mocking.MockHub):
        def init_logging(self):
            # No-op override: skip the hub's own logging setup.
            pass

    hub = MockHub(config=hub_config)

    # Configure gateway
    config = configure_dask_gateway(jhub_api_token, jhub_bind_url)

    async with temp_gateway(config=config) as g:
        async with temp_hub(hub):
            # Create a new jupyterhub user alice, and get the api token
            u = add_user(hub.db, name="alice")
            api_token = u.new_api_token()
            # Commit so the new token is visible when the hub validates it.
            hub.db.commit()

            # Configure auth with incorrect api token
            auth = JupyterHubAuth(api_token=uuid.uuid4().hex)

            async with g.gateway_client(auth=auth) as gateway:
                # Auth fails with bad token
                with pytest.raises(Exception):
                    await gateway.list_clusters()

                # Auth works with correct token (swapped in on the same
                # ``auth`` object the client already holds).
                auth.api_token = api_token
                await gateway.list_clusters()
Example #19
0
async def test_jupyterhub_auth_service(monkeypatch):
    """Check that a JupyterHub *service* token authenticates with the gateway.

    A mock hub registers two services, ``dask-gateway`` and ``any-service``.
    The gateway is configured via ``configure_dask_gateway`` with the
    ``dask-gateway`` token. A random token must be rejected, and a service
    token must be accepted.
    """
    jhub_api_token = uuid.uuid4().hex
    # NOTE(review): this token is registered for "any-service" below but is
    # never used afterwards; the final check uses ``jhub_api_token`` instead.
    # Confirm whether that check was meant to use ``jhub_service_token``.
    jhub_service_token = uuid.uuid4().hex
    jhub_bind_url = "http://127.0.0.1:%i/@/space%%20word/" % random_port()

    hub_config = Config()
    hub_config.JupyterHub.services = [
        {
            "name": "dask-gateway",
            "api_token": jhub_api_token
        },
        {
            "name": "any-service",
            "api_token": jhub_service_token
        },
    ]
    hub_config.JupyterHub.bind_url = jhub_bind_url

    class MockHub(hub_mocking.MockHub):
        def init_logging(self):
            # No-op override: skip the hub's own logging setup.
            pass

    hub = MockHub(config=hub_config)

    # Configure gateway
    config = configure_dask_gateway(jhub_api_token, jhub_bind_url)

    async with temp_gateway(config=config) as g:
        async with temp_hub(hub):
            # Configure auth with incorrect api token
            auth = JupyterHubAuth(api_token=uuid.uuid4().hex)
            async with g.gateway_client(auth=auth) as gateway:
                # Auth fails with bad token
                with pytest.raises(Exception):
                    await gateway.list_clusters()

                # Auth works with service token (here: the dask-gateway
                # service's own token).
                auth.api_token = jhub_api_token
                await gateway.list_clusters()
Example #20
0
 def __init__(self):
     """Build an app serving ``SlowHandler`` at ``/`` on a random port.

     ``self.tasks`` is a set handed to the application as its ``task_set``
     setting. ``self.address`` is the ``http://127.0.0.1:<port>`` URL.
     """
     chosen_port = random_port()
     self.port = chosen_port
     self.address = "http://127.0.0.1:%d" % chosen_port
     self.tasks = set()
     handlers = [(r"/", SlowHandler)]
     self.app = web.Application(handlers, task_set=self.tasks)
async def test_gateway_resume_clusters_after_shutdown(tmpdir):
    """Clusters left running at shutdown are picked up by a new gateway.

    Phase 1: with ``stop_clusters_on_shutdown = False`` and a file-backed,
    encrypted SQLite database, create cluster 1 (scaled to 2 workers),
    cluster 2, and a third cluster via ``new_cluster``. Then shut down.
    Phase 2: kill one worker of cluster 1 and the process recorded for
    cluster 2, using the pids stored in their state.
    Phase 3: start a new gateway on the same config, now with
    ``stop_clusters_on_shutdown = True``. Expect only cluster 1 to be
    active, with 2 active workers, and expect it to still run tasks.
    """
    db_url = "sqlite:///%s" % tmpdir.join("dask_gateway.sqlite")
    db_encrypt_keys = [Fernet.generate_key()]

    config = Config()
    config.DaskGateway.backend_class = LocalTestingBackend
    config.LocalTestingBackend.db_url = db_url
    config.LocalTestingBackend.db_encrypt_keys = db_encrypt_keys
    # Keep clusters alive across the first gateway's shutdown.
    config.LocalTestingBackend.stop_clusters_on_shutdown = False
    config.LocalTestingBackend.cluster_heartbeat_period = 1
    config.LocalTestingBackend.check_timeouts_period = 0.5
    config.DaskGateway.address = "127.0.0.1:%d" % random_port()
    config.Proxy.address = "127.0.0.1:%d" % random_port()

    async with temp_gateway(config=config) as g:
        async with g.gateway_client() as gateway:
            cluster1_name = await gateway.submit()
            async with gateway.connect(cluster1_name) as c:
                await c.scale(2)

            cluster2_name = await gateway.submit()
            async with gateway.connect(cluster2_name):
                pass

            # A third cluster, opened and closed via ``new_cluster``; it is
            # presumably shut down on exit, since the restart check below
            # expects only cluster 1 to remain active.
            async with gateway.new_cluster():
                pass

    # Read the (now shut down) first gateway's database directly.
    active_clusters = {
        c.name: c
        for c in g.gateway.backend.db.active_clusters()
    }

    # Active clusters are not stopped on shutdown
    assert active_clusters

    # Stop 1 worker in cluster 1
    worker = list(active_clusters[cluster1_name].workers.values())[0]
    pid = worker.state["pid"]
    os.kill(pid, signal.SIGTERM)

    # Stop cluster 2
    pid = active_clusters[cluster2_name].state["pid"]
    os.kill(pid, signal.SIGTERM)

    # Restart a new temp_gateway
    # (this time clusters are stopped when the gateway shuts down).
    config.LocalTestingBackend.stop_clusters_on_shutdown = True
    async with temp_gateway(config=config) as g:
        backend = g.gateway.backend
        # Poll up to 10 times, 0.5s apart, until the resumed backend's view
        # matches; re-raise the last AssertionError if it never does.
        for retry in range(10):
            try:
                clusters = list(backend.db.active_clusters())
                assert len(clusters) == 1

                cluster = clusters[0]

                assert cluster.name == cluster1_name
                # Presumably the 2 original worker records plus at least one
                # replacement for the killed worker -- TODO confirm.
                assert len(cluster.workers) >= 3
                assert len(cluster.active_workers()) == 2
                break
            except AssertionError:
                if retry < 9:
                    await asyncio.sleep(0.5)
                else:
                    raise
        # Check that cluster is available and everything still works
        async with g.gateway_client() as gateway:
            async with gateway.connect(cluster1_name,
                                       shutdown_on_close=True) as cluster:
                async with cluster.get_client(set_as_default=False) as client:
                    res = await client.submit(lambda x: x + 1, 1)
                    assert res == 2