Example #1
def test_close_loop_sync(with_own_loop):
    loop_runner = loop = None

    # Set up a simple cluster with one threaded worker.
    # A complex setup is not required here since we only test IO loop teardown.
    cluster_params = dict(n_workers=1, dashboard_address=None, processes=False)

    loops_before = LoopRunner._all_loops.copy()

    # Start our own loop, or use the current thread's loop.
    if with_own_loop:
        loop_runner = LoopRunner()
        loop_runner.start()
        loop = loop_runner.loop

    with LocalCluster(loop=loop, **cluster_params) as cluster:
        with Client(cluster, loop=loop) as client:
            client.run(max, 1, 2)

    # Our own loop must be stopped explicitly.
    if with_own_loop:
        loop_runner.stop()

    # The internal loop registry must be the same as before the cluster ran.
    # This means the loop runners in LocalCluster and Client stopped correctly.
    # See LoopRunner._stop_unlocked().
    assert loops_before == LoopRunner._all_loops
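
The assertion works because LoopRunner keeps a class-level registry of loops
(LoopRunner._all_loops) and reference-counts start/stop calls. A minimal
sketch of that lifecycle, assuming distributed's LoopRunner API as used above:

from distributed.utils import LoopRunner

runner = LoopRunner()           # no loop given: a fresh IOLoop is created
runner.start()                  # runs the loop in a background thread
assert runner.is_started()
runner.stop()                   # the last stop tears the loop down again
assert not runner.is_started()
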
Example #2
async def test_loop_runner_gen():
    runner = LoopRunner(asynchronous=True)
    assert runner.loop is IOLoop.current()
    assert not runner.is_started()
    await asyncio.sleep(0.01)
    runner.start()
    assert runner.is_started()
    await asyncio.sleep(0.01)
    runner.stop()
    assert not runner.is_started()
    await asyncio.sleep(0.01)
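
The test above is a native coroutine, so outside a test runner it needs an
event loop to drive it, e.g. (a sketch, not part of the original test):

import asyncio

asyncio.get_event_loop().run_until_complete(test_loop_runner_gen())
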
Example #3
def test_loop_runner_gen():
    runner = LoopRunner(asynchronous=True)
    assert runner.loop is IOLoop.current()
    assert not runner.is_started()
    yield gen.sleep(0.01)
    runner.start()
    assert runner.is_started()
    yield gen.sleep(0.01)
    runner.stop()
    assert not runner.is_started()
    yield gen.sleep(0.01)
Example #4
class ContextCluster(Cluster):
    def __init__(self, asynchronous=False, loop=None):
        self._loop_runner = LoopRunner(loop=loop, asynchronous=asynchronous)
        self.loop = self._loop_runner.loop

        super().__init__(asynchronous=asynchronous)

        if not self.asynchronous:
            self._loop_runner.start()
            self.sync(self._start)

    def __enter__(self):
        if self.status != "running":
            raise ValueError(
                f"Expected status 'running', found '{self.status}'")
        return self

    def __exit__(self, typ, value, traceback):
        self.close()
        self._loop_runner.stop()
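
A usage sketch of the pattern above: the loop runner starts in __init__ and
stops in __exit__, so a synchronous caller only needs a plain with block.
MyCluster is a hypothetical subclass that implements _start:

with MyCluster() as cluster:            # hypothetical ContextCluster subclass
    assert cluster.status == "running"  # guaranteed by __enter__
# __exit__ has closed the cluster and stopped its LoopRunner.
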
Example #5
def test_two_loop_runners(loop_in_thread):
    # Loop runners tied to the same loop should cooperate

    # ABCCBA
    loop = IOLoop()
    a = LoopRunner(loop=loop)
    b = LoopRunner(loop=loop)
    assert_not_running(loop)
    a.start()
    assert_running(loop)
    c = LoopRunner(loop=loop)
    b.start()
    assert_running(loop)
    c.start()
    assert_running(loop)
    c.stop()
    assert_running(loop)
    b.stop()
    assert_running(loop)
    a.stop()
    assert_not_running(loop)

    # ABCABC
    loop = IOLoop()
    a = LoopRunner(loop=loop)
    b = LoopRunner(loop=loop)
    assert_not_running(loop)
    a.start()
    assert_running(loop)
    b.start()
    assert_running(loop)
    c = LoopRunner(loop=loop)
    c.start()
    assert_running(loop)
    a.stop()
    assert_running(loop)
    b.stop()
    assert_running(loop)
    c.stop()
    assert_not_running(loop)

    # Explicit loop, already started
    a = LoopRunner(loop=loop_in_thread)
    b = LoopRunner(loop=loop_in_thread)
    assert_running(loop_in_thread)
    a.start()
    assert_running(loop_in_thread)
    b.start()
    assert_running(loop_in_thread)
    a.stop()
    assert_running(loop_in_thread)
    b.stop()
    assert_running(loop_in_thread)
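
Both orderings above reduce to one invariant: a loop shared by several
runners keeps running until the last runner attached to it stops. A compact
sketch of that invariant, reusing the same assert helpers:

from tornado.ioloop import IOLoop

loop = IOLoop()
runners = [LoopRunner(loop=loop) for _ in range(3)]
for r in runners:
    r.start()
assert_running(loop)
for r in runners[:-1]:
    r.stop()
assert_running(loop)        # one runner still holds the loop open
runners[-1].stop()
assert_not_running(loop)    # the last stop shuts it down
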
Example #6
def test_loop_runner(loop_in_thread):
    # Implicit loop
    loop = IOLoop()
    loop.make_current()
    runner = LoopRunner()
    assert runner.loop not in (loop, loop_in_thread)
    assert not runner.is_started()
    assert_not_running(runner.loop)
    runner.start()
    assert runner.is_started()
    assert_running(runner.loop)
    runner.stop()
    assert not runner.is_started()
    assert_not_running(runner.loop)

    # Explicit loop
    loop = IOLoop()
    runner = LoopRunner(loop=loop)
    assert runner.loop is loop
    assert not runner.is_started()
    assert_not_running(loop)
    runner.start()
    assert runner.is_started()
    assert_running(loop)
    runner.stop()
    assert not runner.is_started()
    assert_not_running(loop)

    # Explicit loop, already started
    runner = LoopRunner(loop=loop_in_thread)
    assert not runner.is_started()
    assert_running(loop_in_thread)
    runner.start()
    assert runner.is_started()
    assert_running(loop_in_thread)
    runner.stop()
    assert not runner.is_started()
    assert_running(loop_in_thread)

    # Implicit loop, asynchronous=True
    loop = IOLoop()
    loop.make_current()
    runner = LoopRunner(asynchronous=True)
    assert runner.loop is loop
    assert not runner.is_started()
    assert_not_running(runner.loop)
    runner.start()
    assert runner.is_started()
    assert_not_running(runner.loop)
    runner.stop()
    assert not runner.is_started()
    assert_not_running(runner.loop)

    # Explicit loop, asynchronous=True
    loop = IOLoop()
    runner = LoopRunner(loop=loop, asynchronous=True)
    assert runner.loop is loop
    assert not runner.is_started()
    assert_not_running(runner.loop)
    runner.start()
    assert runner.is_started()
    assert_not_running(runner.loop)
    runner.stop()
    assert not runner.is_started()
    assert_not_running(runner.loop)
Example #7
class YarnCluster(object):
    """Start a Dask cluster on YARN.

    You can define default values for this in Dask's ``yarn.yaml``
    configuration file. See http://docs.dask.org/en/latest/configuration.html
    for more information.

    Parameters
    ----------
    environment : str, optional
        The Python environment to use. Can be one of the following:

          - A path to an archived Python environment
          - A path to a conda environment, specified as `conda:///...`
          - A path to a virtual environment, specified as `venv:///...`
          - A path to a python executable, specified as `python:///...`

        Note that if not an archive, the paths specified must be valid on all
        nodes in the cluster.
    n_workers : int, optional
        The number of workers to initially start.
    worker_vcores : int, optional
        The number of virtual cores to allocate per worker.
    worker_memory : str, optional
        The amount of memory to allocate per worker. Accepts a unit suffix
        (e.g. '2 GiB' or '4096 MiB'). Will be rounded up to the nearest MiB.
    worker_restarts : int, optional
        The maximum number of worker restarts to allow before failing the
        application. Default is unlimited.
    worker_env : dict, optional
        A mapping of environment variables to their values. These will be set
        in the worker containers before starting the dask workers.
    scheduler_vcores : int, optional
        The number of virtual cores to allocate to the scheduler.
    scheduler_memory : str, optional
        The amount of memory to allocate to the scheduler. Accepts a unit
        suffix (e.g. '2 GiB' or '4096 MiB'). Will be rounded up to the nearest
        MiB.
    deploy_mode : {'remote', 'local'}, optional
        The deploy mode to use. If ``'remote'``, the scheduler will be deployed
        in a YARN container. If ``'local'``, the scheduler will run locally,
        which can be nice for debugging. Default is ``'remote'``.
    name : str, optional
        The application name.
    queue : str, optional
        The queue to deploy to.
    tags : sequence, optional
        A set of strings to use as tags for this application.
    user : str, optional
        The user to submit the application on behalf of. Default is the
        current user. Submitting as a different user requires user
        permissions; see the YARN documentation for more information.
    host : str, optional
        Host address on which the scheduler will listen. Only used if
        ``deploy_mode='local'``. Defaults to ``'0.0.0.0'``.
    port : int, optional
        The port on which the scheduler will listen. Only used if
        ``deploy_mode='local'``. Defaults to ``0`` for a random port.
    dashboard_address : str, optional
        The address on which the dashboard server will listen. Only used if
        ``deploy_mode='local'``. Defaults to ``':0'`` for a random port.
    skein_client : skein.Client, optional
        The ``skein.Client`` to use. If not provided, one will be started.
    asynchronous : bool, optional
        If true, starts the cluster in asynchronous mode, where it can be used
        in other async code.
    loop : IOLoop, optional
        The IOLoop instance to use. Defaults to the current loop in
        asynchronous mode, otherwise a background loop is started.

    Examples
    --------
    >>> cluster = YarnCluster(environment='my-env.tar.gz', ...)
    >>> cluster.scale(10)
    """
    def __init__(
        self,
        environment=None,
        n_workers=None,
        worker_vcores=None,
        worker_memory=None,
        worker_restarts=None,
        worker_env=None,
        scheduler_vcores=None,
        scheduler_memory=None,
        deploy_mode=None,
        name=None,
        queue=None,
        tags=None,
        user=None,
        host=None,
        port=None,
        dashboard_address=None,
        skein_client=None,
        asynchronous=False,
        loop=None,
    ):
        spec = _make_specification(
            environment=environment,
            n_workers=n_workers,
            worker_vcores=worker_vcores,
            worker_memory=worker_memory,
            worker_restarts=worker_restarts,
            worker_env=worker_env,
            scheduler_vcores=scheduler_vcores,
            scheduler_memory=scheduler_memory,
            deploy_mode=deploy_mode,
            name=name,
            queue=queue,
            tags=tags,
            user=user,
        )
        self._init_common(
            spec=spec,
            host=host,
            port=port,
            dashboard_address=dashboard_address,
            asynchronous=asynchronous,
            loop=loop,
            skein_client=skein_client,
        )

    @classmethod
    def from_specification(cls,
                           spec,
                           skein_client=None,
                           asynchronous=False,
                           loop=None):
        """Start a dask cluster from a skein specification.

        Parameters
        ----------
        spec : skein.ApplicationSpec, dict, or filename
            The application specification to use. Must define at least one
            service: ``'dask.worker'``. If no ``'dask.scheduler'`` service is
            defined, a scheduler will be started locally.
        skein_client : skein.Client, optional
            The ``skein.Client`` to use. If not provided, one will be started.
        asynchronous : bool, optional
            If true, starts the cluster in asynchronous mode, where it can be
            used in other async code.
        loop : IOLoop, optional
            The IOLoop instance to use. Defaults to the current loop in
            asynchronous mode, otherwise a background loop is started.
        """
        self = super(YarnCluster, cls).__new__(cls)
        if isinstance(spec, dict):
            spec = skein.ApplicationSpec.from_dict(spec)
        elif isinstance(spec, str):
            spec = skein.ApplicationSpec.from_file(spec)
        elif not isinstance(spec, skein.ApplicationSpec):
            raise TypeError("spec must be an ApplicationSpec, dict, or path, "
                            "got %r" % type(spec).__name__)
        if "dask.worker" not in spec.services:
            raise ValueError(
                "Provided Skein specification must include a 'dask.worker' service"
            )

        self._init_common(spec=spec,
                          asynchronous=asynchronous,
                          loop=loop,
                          skein_client=skein_client)
        return self

    @classmethod
    def from_current(cls, asynchronous=False, loop=None):
        """Connect to an existing ``YarnCluster`` from inside the cluster.

        Parameters
        ----------
        asynchronous : bool, optional
            If true, starts the cluster in asynchronous mode, where it can be
            used in other async code.
        loop : IOLoop, optional
            The IOLoop instance to use. Defaults to the current loop in
            asynchronous mode, otherwise a background loop is started.

        Returns
        -------
        YarnCluster
        """
        self = super(YarnCluster, cls).__new__(cls)
        app_id = os.environ.get("DASK_APPLICATION_ID", None)
        app_address = os.environ.get("DASK_APPMASTER_ADDRESS", None)
        security_dir = os.environ.get("DASK_SECURITY_CREDENTIALS", None)
        if app_id is not None and app_address is not None:
            security = (None if security_dir is None else
                        skein.Security.from_directory(security_dir))
            app = skein.ApplicationClient(app_address,
                                          app_id,
                                          security=security)
        else:
            app = skein.ApplicationClient.from_current()

        self._init_common(application_client=app,
                          asynchronous=asynchronous,
                          loop=loop)
        return self

    @classmethod
    def from_application_id(cls,
                            app_id,
                            skein_client=None,
                            asynchronous=False,
                            loop=None):
        """Connect to an existing ``YarnCluster`` with a given application id.

        Parameters
        ----------
        app_id : str
            The existing cluster's application id.
        skein_client : skein.Client
            The ``skein.Client`` to use. If not provided, one will be started.
        asynchronous : bool, optional
            If true, starts the cluster in asynchronous mode, where it can be
            used in other async code.
        loop : IOLoop, optional
            The IOLoop instance to use. Defaults to the current loop in
            asynchronous mode, otherwise a background loop is started.

        Returns
        -------
        YarnCluster
        """
        self = super(YarnCluster, cls).__new__(cls)
        skein_client = _get_skein_client(skein_client)
        app = skein_client.connect(app_id)

        self._init_common(
            application_client=app,
            asynchronous=asynchronous,
            loop=loop,
            skein_client=skein_client,
        )
        return self

    def _init_common(
        self,
        spec=None,
        application_client=None,
        host=None,
        port=None,
        dashboard_address=None,
        asynchronous=False,
        loop=None,
        skein_client=None,
    ):
        self.spec = spec
        self.application_client = application_client
        self._scheduler_kwargs = _make_scheduler_kwargs(
            host=host,
            port=port,
            dashboard_address=dashboard_address,
        )
        self._scheduler = None
        self.scheduler_info = {}
        self._requested = set()
        self.scheduler_comm = None
        self._watch_worker_status_task = None
        self._start_task = None
        self._stop_task = None
        self._finalizer = None
        self._adaptive = None
        self._adaptive_options = {}
        self._skein_client = skein_client
        self._asynchronous = asynchronous
        self._loop_runner = LoopRunner(loop=loop, asynchronous=asynchronous)
        self._loop_runner.start()

        if not self.asynchronous:
            self._sync(self._start_internal())

    def _start_cluster(self):
        """Start the cluster and initialize state"""

        skein_client = _get_skein_client(self._skein_client)

        if "dask.scheduler" not in self.spec.services:
            # deploy_mode == 'local'
            scheduler_address = self._scheduler.address
            for k in ["dashboard", "bokeh"]:
                if k in self._scheduler.services:
                    dashboard_port = self._scheduler.services[k].port
                    dashboard_host = urlparse(scheduler_address).hostname
                    dashboard_address = "http://%s:%d" % (
                        dashboard_host,
                        dashboard_port,
                    )
                    break
            else:
                dashboard_address = None

            with submit_and_handle_failures(skein_client, self.spec) as app:
                app.kv["dask.scheduler"] = scheduler_address.encode()
                if dashboard_address is not None:
                    app.kv["dask.dashboard"] = dashboard_address.encode()
        else:
            # deploy_mode == 'remote'
            with submit_and_handle_failures(skein_client, self.spec) as app:
                scheduler_address = app.kv.wait("dask.scheduler").decode()
                dashboard_address = app.kv.get("dask.dashboard")
                if dashboard_address is not None:
                    dashboard_address = dashboard_address.decode()

        # Ensure application gets cleaned up
        self._finalizer = weakref.finalize(self, app.shutdown)

        self.scheduler_address = scheduler_address
        self._dashboard_address = dashboard_address
        self.application_client = app

    def _connect_existing(self):
        spec = self.application_client.get_specification()
        if "dask.worker" not in spec.services:
            raise ValueError("%r is not a valid dask cluster" % self.app_id)

        scheduler_address = self.application_client.kv.wait(
            "dask.scheduler").decode()
        dashboard_address = self.application_client.kv.get("dask.dashboard")
        if dashboard_address is not None:
            dashboard_address = dashboard_address.decode()

        self.spec = spec
        self.scheduler_address = scheduler_address
        self._dashboard_address = dashboard_address

    async def _start_internal(self):
        if self._start_task is None:
            self._start_task = asyncio.ensure_future(self._start_async())
        try:
            await self._start_task
        except BaseException:
            # On exception, cleanup
            await self._stop_internal()
            raise
        return self

    async def _start_async(self):
        if self.spec is not None:
            # Start a new cluster
            if "dask.scheduler" not in self.spec.services:
                self._scheduler = Scheduler(
                    loop=self.loop,
                    **self._scheduler_kwargs,
                )
                await self._scheduler
            else:
                self._scheduler = None
            await self.loop.run_in_executor(None, self._start_cluster)
        else:
            # Connect to an existing cluster
            await self.loop.run_in_executor(None, self._connect_existing)

        self.scheduler_comm = rpc(self.scheduler_address)
        comm = None
        try:
            comm = await self.scheduler_comm.live_comm()
            await comm.write({"op": "subscribe_worker_status"})
            self.scheduler_info = await comm.read()
            workers = self.scheduler_info.get("workers", {})
            self._requested.update(w["name"] for w in workers.values())
            self._watch_worker_status_task = asyncio.ensure_future(
                self._watch_worker_status(comm))
        except Exception:
            if comm is not None:
                await comm.close()

    async def _stop_internal(self, status="SUCCEEDED", diagnostics=None):
        if self._stop_task is None:
            self._stop_task = asyncio.ensure_future(
                self._stop_async(status=status, diagnostics=diagnostics))
        await self._stop_task

    async def _stop_async(self, status="SUCCEEDED", diagnostics=None):
        if self._start_task is not None:
            if not self._start_task.done():
                # We're still starting, cancel task
                await cancel_task(self._start_task)
            self._start_task = None

        if self._adaptive is not None:
            self._adaptive.stop()

        if self._watch_worker_status_task is not None:
            await cancel_task(self._watch_worker_status_task)
            self._watch_worker_status_task = None

        if self.scheduler_comm is not None:
            self.scheduler_comm.close_rpc()
            self.scheduler_comm = None

        await self.loop.run_in_executor(
            None,
            lambda: self._stop_sync(status=status, diagnostics=diagnostics))

        if self._scheduler is not None:
            await self._scheduler.close()
            self._scheduler = None

    def _stop_sync(self, status="SUCCEEDED", diagnostics=None):
        if self._finalizer is not None and self._finalizer.peek() is not None:
            self.application_client.shutdown(status=status,
                                             diagnostics=diagnostics)
            self._finalizer.detach()  # don't run the finalizer later
        self._finalizer = None

    def __await__(self):
        return self.__aenter__().__await__()

    async def __aenter__(self):
        return await self._start_internal()

    async def __aexit__(self, typ, value, traceback):
        await self._stop_internal()

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

    def __repr__(self):
        return "YarnCluster<%s>" % self.app_id

    @property
    def loop(self):
        return self._loop_runner.loop

    @property
    def app_id(self):
        return self.application_client.id

    @property
    def asynchronous(self):
        return self._asynchronous

    def _sync(self, task):
        if self.asynchronous:
            return task
        future = asyncio.run_coroutine_threadsafe(task, self.loop.asyncio_loop)
        try:
            return future.result()
        except BaseException:
            future.cancel()
            raise

    @cached_property
    def dashboard_link(self):
        """Link to the dask dashboard. None if dashboard isn't running"""
        if self._dashboard_address is None:
            return None
        dashboard = urlparse(self._dashboard_address)
        return format_dashboard_link(dashboard.hostname, dashboard.port)

    def shutdown(self, status="SUCCEEDED", diagnostics=None):
        """Shutdown the application.

        Parameters
        ----------
        status : {'SUCCEEDED', 'FAILED', 'KILLED'}, optional
            The yarn application exit status.
        diagnostics : str, optional
            The application exit message, usually used for diagnosing failures.
            Can be seen in the YARN Web UI for completed applications under
            "diagnostics". If not provided, a default will be used.
        """
        if self.asynchronous:
            return self._stop_internal(status=status, diagnostics=diagnostics)
        if self.loop.asyncio_loop.is_running() and not sys.is_finalizing():
            self._sync(
                self._stop_internal(status=status, diagnostics=diagnostics))
        else:
            # Always run this!
            self._stop_sync(status=status, diagnostics=diagnostics)
        self._loop_runner.stop()

    def close(self, **kwargs):
        """Close this cluster. An alias for ``shutdown``.

        See Also
        --------
        shutdown
        """
        return self.shutdown(**kwargs)

    def __del__(self):
        if not hasattr(self, "_loop_runner"):
            return
        if self.asynchronous:
            # No del for async mode
            return
        self.close()

    @property
    def _observed(self):
        return {w["name"] for w in self.scheduler_info["workers"].values()}

    async def _workers(self):
        return await self.loop.run_in_executor(
            None,
            lambda: self.application_client.get_containers(
                services=["dask.worker"]),
        )

    def workers(self):
        """A list of all currently running worker containers."""
        return self._sync(self._workers())

    async def _logs(self, scheduler=True, workers=True):
        logs = Logs()

        if scheduler:
            slogs = await self.scheduler_comm.logs()
            logs["Scheduler"] = Log("\n".join(line for level, line in slogs))

        if workers:
            d = await self.scheduler_comm.worker_logs(workers=workers)
            for k, v in d.items():
                logs[k] = Log("\n".join(line for level, line in v))

        return logs

    def logs(self, scheduler=True, workers=True):
        """Return logs for the scheduler and/or workers

        Parameters
        ----------
        scheduler : boolean, optional
            Whether or not to collect logs for the scheduler
        workers : boolean or iterable, optional
            A list of worker addresses to select. Defaults to all workers if
            ``True`` or no workers if ``False``

        Returns
        -------
        logs : dict
            A dictionary of name -> logs.
        """
        return self._sync(self._logs(scheduler=scheduler, workers=workers))

    async def _scale_up(self, n):
        if n > len(self._requested):
            containers = await self.loop.run_in_executor(
                None,
                lambda: self.application_client.scale("dask.worker", n),
            )
            self._requested.update(c.id for c in containers)

    async def _scale_down(self, workers):
        self._requested.difference_update(workers)
        await self.scheduler_comm.retire_workers(names=list(workers))

        def _kill_containers():
            for c in workers:
                try:
                    self.application_client.kill_container(c)
                except ValueError:
                    pass

        await self.loop.run_in_executor(
            None,
            _kill_containers,
        )

    async def _scale(self, n):
        if self._adaptive is not None:
            self._adaptive.stop()
        if n >= len(self._requested):
            return await self._scale_up(n)
        else:
            n_to_delete = len(self._requested) - n
            pending = list(self._requested - self._observed)
            running = list(self._observed.difference(pending))
            to_close = pending[:n_to_delete]
            to_close.extend(running[:n_to_delete - len(to_close)])
            await self._scale_down(to_close)
            self._requested.difference_update(to_close)

    def scale(self, n):
        """Scale cluster to n workers.

        Parameters
        ----------
        n : int
            Target number of workers

        Examples
        --------
        >>> cluster.scale(10)  # scale cluster to ten workers
        """
        return self._sync(self._scale(n))

    def adapt(self,
              minimum=0,
              maximum=math.inf,
              interval="1s",
              wait_count=3,
              target_duration="5s",
              **kwargs):
        """Turn on adaptivity

        This scales Dask clusters automatically based on scheduler activity.

        Parameters
        ----------
        minimum : int, optional
            Minimum number of workers. Defaults to ``0``.
        maximum : int, optional
            Maximum number of workers. Defaults to ``inf``.
        interval : timedelta or str, optional
            Time between worker add/remove recommendations.
        wait_count : int, optional
            Number of consecutive times that a worker should be suggested for
            removal before we remove it.
        target_duration : timedelta or str, optional
            Amount of time we want a computation to take. This affects how
            aggressively we scale up.
        **kwargs :
            Additional parameters to pass to
            ``distributed.Scheduler.workers_to_close``.

        Examples
        --------
        >>> cluster.adapt(minimum=0, maximum=10)
        """
        if self._adaptive is not None:
            self._adaptive.stop()
        self._adaptive_options.update(
            minimum=minimum,
            maximum=maximum,
            interval=interval,
            wait_count=wait_count,
            target_duration=target_duration,
            **kwargs,
        )
        self._adaptive = Adaptive(self, **self._adaptive_options)

    async def _watch_worker_status(self, comm):
        # We don't want to hold on to a ref to self, otherwise this will
        # leave a dangling reference and prevent garbage collection.
        ref_self = weakref.ref(self)
        self = None
        try:
            while True:
                try:
                    msgs = await comm.read()
                except OSError:
                    break
                try:
                    self = ref_self()
                    if self is None:
                        break
                    for op, msg in msgs:
                        if op == "add":
                            workers = msg.pop("workers")
                            self.scheduler_info["workers"].update(workers)
                            self.scheduler_info.update(msg)
                            self._requested.update(w["name"]
                                                   for w in workers.values())
                        elif op == "remove":
                            del self.scheduler_info["workers"][msg]
                            self._requested.discard(msg)
                    if hasattr(self, "_status_widget"):
                        self._status_widget.value = self._widget_status()
                finally:
                    self = None
        finally:
            await comm.close()

    def _widget_status(self):
        try:
            workers = self.scheduler_info["workers"]
        except KeyError:
            return None
        else:
            n_workers = len(workers)
            cores = sum(w["nthreads"] for w in workers.values())
            memory = sum(w["memory_limit"] for w in workers.values())

            return _widget_status_template % (n_workers, cores,
                                              format_bytes(memory))

    def _widget(self):
        """ Create IPython widget for display within a notebook """
        try:
            return self._cached_widget
        except AttributeError:
            pass

        if self.asynchronous:
            return None

        try:
            from ipywidgets import Layout, VBox, HBox, IntText, Button, HTML, Accordion
        except ImportError:
            self._cached_widget = None
            return None

        layout = Layout(width="150px")

        title = HTML("<h2>YarnCluster</h2>")

        status = HTML(self._widget_status(), layout=Layout(min_width="150px"))

        request = IntText(0, description="Workers", layout=layout)
        scale = Button(description="Scale", layout=layout)

        minimum = IntText(0, description="Minimum", layout=layout)
        maximum = IntText(0, description="Maximum", layout=layout)
        adapt = Button(description="Adapt", layout=layout)

        accordion = Accordion(
            [HBox([request, scale]),
             HBox([minimum, maximum, adapt])],
            layout=Layout(min_width="500px"),
        )
        accordion.selected_index = None
        accordion.set_title(0, "Manual Scaling")
        accordion.set_title(1, "Adaptive Scaling")

        @adapt.on_click
        def adapt_cb(b):
            self.adapt(minimum=minimum.value, maximum=maximum.value)

        @scale.on_click
        def scale_cb(b):
            with log_errors():
                self.scale(request.value)

        app_id = HTML("<p><b>Application ID: </b>{0}</p>".format(self.app_id))

        elements = [title, HBox([status, accordion]), app_id]

        if self.dashboard_link is not None:
            link = HTML(
                '<p><b>Dashboard: </b><a href="{0}" target="_blank">{0}'
                "</a></p>\n".format(self.dashboard_link))
            elements.append(link)

        self._cached_widget = box = VBox(elements)
        self._status_widget = status

        return box

    def _ipython_display_(self, **kwargs):
        widget = self._widget()
        if widget is not None:
            return widget._ipython_display_(**kwargs)
        else:
            from IPython.display import display

            data = {"text/plain": repr(self), "text/html": self._repr_html_()}
            display(data, raw=True)

    def _repr_html_(self):
        if self.dashboard_link is not None:
            dashboard = "<a href='{0}' target='_blank'>{0}</a>".format(
                self.dashboard_link)
        else:
            dashboard = "Not Available"
        return (
            "<div style='background-color: #f2f2f2; display: inline-block; "
            "padding: 10px; border: 1px solid #999999;'>\n"
            "  <h3>YarnCluster</h3>\n"
            "  <ul>\n"
            "    <li><b>Application ID: </b>{app_id}\n"
            "    <li><b>Dashboard: </b>{dashboard}\n"
            "  </ul>\n"
            "</div>\n").format(app_id=self.app_id, dashboard=dashboard)
Example #8
class Gateway(object):
    """A client for a Dask Gateway Server.

    Parameters
    ----------
    address : str, optional
        The address of the gateway server.
    proxy_address : str or int, optional
        The address of the scheduler proxy server. If an int, it is used as
        the port, with the host/IP taken from ``address``. Provide a full
        address if a different host/IP should be used.
    auth : GatewayAuth, optional
        The authentication method to use.
    asynchronous : bool, optional
        If true, starts the client in asynchronous mode, where it can be used
        in other async code.
    loop : IOLoop, optional
        The IOLoop instance to use. Defaults to the current loop in
        asynchronous mode, otherwise a background loop is started.
    """
    def __init__(self,
                 address=None,
                 proxy_address=None,
                 auth=None,
                 asynchronous=False,
                 loop=None):
        if address is None:
            address = dask.config.get("gateway.address")
        if address is None:
            raise ValueError(
                "No dask-gateway address provided or found in configuration")
        address = address.rstrip("/")

        if proxy_address is None:
            proxy_address = dask.config.get("gateway.proxy-address")
        if proxy_address is None:
            raise ValueError(
                "No dask-gateway proxy address provided or found in configuration"
            )
        if isinstance(proxy_address, int):
            parsed = urlparse(address)
            proxy_netloc = "%s:%d" % (parsed.hostname, proxy_address)
        elif isinstance(proxy_address, str):
            parsed = urlparse(proxy_address)
            proxy_netloc = parsed.netloc if parsed.netloc else proxy_address
        else:
            raise TypeError("proxy_address must be a str or int, got %r"
                            % type(proxy_address).__name__)
        proxy_address = "gateway://%s" % proxy_netloc

        self.address = address
        self.proxy_address = proxy_address

        self._auth = get_auth(auth)
        self._cookie_jar = CookieJar()

        self._asynchronous = asynchronous
        self._loop_runner = LoopRunner(loop=loop, asynchronous=asynchronous)
        self.loop = self._loop_runner.loop

        self._loop_runner.start()
        if self._asynchronous:
            self._started = self._start()
        else:
            self.sync(self._start)

    async def _start(self):
        self._http_client = AsyncHTTPClient()

    def close(self):
        """Close this gateway client"""
        if not self.asynchronous:
            self._loop_runner.stop()

    def __del__(self):
        # __del__ is still called, even if __init__ failed. Only close if init
        # actually succeeded.
        if hasattr(self, "_started"):
            self.close()

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

    def __await__(self):
        return self._started.__await__()

    async def __aenter__(self):
        await self._started
        return self

    async def __aexit__(self, typ, value, traceback):
        pass

    def __repr__(self):
        return "Gateway<%s>" % self.address

    @property
    def asynchronous(self):
        return (self._asynchronous
                or getattr(thread_state, "asynchronous", False)
                or (hasattr(self.loop, "_thread_identity")
                    and self.loop._thread_identity == get_ident()))

    def sync(self, func, *args, **kwargs):
        if kwargs.pop("asynchronous", None) or self.asynchronous:
            callback_timeout = kwargs.pop("callback_timeout", None)
            future = func(*args, **kwargs)
            if callback_timeout is not None:
                future = gen.with_timeout(timedelta(seconds=callback_timeout),
                                          future)
            return future
        else:
            return sync(self.loop, func, *args, **kwargs)

    async def _fetch(self, req):
        try:
            self._cookie_jar.pre_request(req)
            resp = await self._http_client.fetch(req, raise_error=False)
            if resp.code == 401:
                context = self._auth.pre_request(req, resp)
                resp = await self._http_client.fetch(req, raise_error=False)
                self._auth.post_response(req, resp, context)
            self._cookie_jar.post_response(resp)
            if resp.error:
                if resp.code == 599:
                    raise TimeoutError("Request timed out")
                else:
                    try:
                        msg = json.loads(resp.body)["error"]
                    except Exception:
                        msg = resp.body.decode()

                    if resp.code in {404, 422}:
                        raise ValueError(msg)
                    elif resp.code == 409:
                        raise GatewayClusterError(msg)
                    elif resp.code == 500:
                        raise GatewayServerError(msg)
                    else:
                        resp.rethrow()
        except HTTPError as exc:
            # Tornado 6 still raises these above with raise_error=False
            if exc.code == 599:
                raise TimeoutError("Request timed out")
            # Should never get here!
            raise
        return resp

    async def _clusters(self, status=None):
        if status is not None:
            if isinstance(status, (str, ClusterStatus)):
                status = [ClusterStatus._create(status)]
            else:
                status = [ClusterStatus._create(s) for s in status]
            query = "?status=" + ",".join(s.name for s in status)
        else:
            query = ""

        url = "%s/gateway/api/clusters/%s" % (self.address, query)
        req = HTTPRequest(url=url)
        resp = await self._fetch(req)
        return [
            ClusterReport._from_json(self.address, self.proxy_address, r)
            for r in json.loads(resp.body).values()
        ]

    def list_clusters(self, status=None, **kwargs):
        """List clusters for this user.

        Parameters
        ----------
        status : ClusterStatus, str, or list, optional
            The cluster status (or statuses) to select. Valid options are
            'starting', 'started', 'running', 'stopping', 'stopped', 'failed'.
            By default selects active clusters ('starting', 'started',
            'running').

        Returns
        -------
        clusters : list of ClusterReport
        """
        return self.sync(self._clusters, status=status, **kwargs)

    async def _cluster_options(self):
        url = "%s/gateway/api/clusters/options" % self.address
        req = HTTPRequest(url=url, method="GET")
        resp = await self._fetch(req)
        data = json.loads(resp.body)
        return Options._from_spec(data["cluster_options"])

    def cluster_options(self, **kwargs):
        """Get the available cluster configuration options.

        Returns
        -------
        cluster_options : Options
            A dict of cluster options.
        """
        return self.sync(self._cluster_options, **kwargs)

    async def _submit(self, cluster_options=None, **kwargs):
        url = "%s/gateway/api/clusters/" % self.address
        if cluster_options is not None:
            if not isinstance(cluster_options, Options):
                raise TypeError(
                    "cluster_options must be an `Options`, got %r" %
                    type(cluster_options).__name__)
            options = dict(cluster_options)
            options.update(kwargs)
        else:
            options = kwargs
        req = HTTPRequest(
            url=url,
            method="POST",
            body=json.dumps({"cluster_options": options}),
            headers=HTTPHeaders({"Content-type": "application/json"}),
        )
        resp = await self._fetch(req)
        data = json.loads(resp.body)
        return data["name"]

    def submit(self, cluster_options=None, **kwargs):
        """Submit a new cluster to be started.

        This returns quickly with a ``cluster_name``, which can later be used
        to connect to the cluster.

        Parameters
        ----------
        cluster_options : Options, optional
            An ``Options`` object describing the desired cluster configuration.
        **kwargs :
            Additional cluster configuration options. If ``cluster_options`` is
            provided, these are applied afterwards as overrides. Available
            options are specific to each deployment of dask-gateway, see
            ``cluster_options`` for more information.

        Returns
        -------
        cluster_name : str
            The cluster name.
        """
        return self.sync(self._submit,
                         cluster_options=cluster_options,
                         **kwargs)

    async def _cluster_report(self, cluster_name, wait=False):
        params = "?wait" if wait else ""
        url = "%s/gateway/api/clusters/%s%s" % (self.address, cluster_name,
                                                params)
        req = HTTPRequest(url=url)
        resp = await self._fetch(req)
        return ClusterReport._from_json(self.address, self.proxy_address,
                                        json.loads(resp.body))

    async def _connect(self, cluster_name):
        while True:
            try:
                report = await self._cluster_report(cluster_name, wait=True)
            except TimeoutError:
                # Timeout, ignore
                pass
            else:
                if report.status is ClusterStatus.RUNNING:
                    return GatewayCluster(
                        gateway=self,
                        name=report.name,
                        scheduler_address=report.scheduler_address,
                        dashboard_link=report.dashboard_link,
                        security=report.security,
                    )
                elif report.status is ClusterStatus.FAILED:
                    raise GatewayClusterError(
                        "Cluster %r failed to start, see logs for "
                        "more information" % cluster_name)
                elif report.status is ClusterStatus.STOPPED:
                    raise GatewayClusterError("Cluster %r is already stopped" %
                                              cluster_name)
            # Not started yet, try again later
            await gen.sleep(0.5)

    def connect(self, cluster_name, **kwargs):
        """Connect to a submitted cluster.

        Returns
        -------
        cluster : GatewayCluster
        """
        return self.sync(self._connect, cluster_name, **kwargs)

    async def _new_cluster(self, **kwargs):
        cluster_name = await self._submit(**kwargs)
        try:
            return await self._connect(cluster_name)
        except GatewayClusterError:
            raise
        except BaseException:
            # Ensure cluster is stopped on error
            await self._stop_cluster(cluster_name)
            raise

    def new_cluster(self, cluster_options=None, **kwargs):
        """Submit a new cluster to the gateway, and wait for it to be started.

        Same as calling ``submit`` and ``connect`` in one go.

        Parameters
        ----------
        cluster_options : Options, optional
            An ``Options`` object describing the desired cluster configuration.
        **kwargs :
            Additional cluster configuration options. If ``cluster_options`` is
            provided, these are applied afterwards as overrides. Available
            options are specific to each deployment of dask-gateway, see
            ``cluster_options`` for more information.

        Returns
        -------
        cluster : GatewayCluster
        """
        return self.sync(self._new_cluster,
                         cluster_options=cluster_options,
                         **kwargs)

    async def _stop_cluster(self, cluster_name):
        url = "%s/gateway/api/clusters/%s" % (self.address, cluster_name)
        req = HTTPRequest(url=url, method="DELETE")
        await self._fetch(req)

    def stop_cluster(self, cluster_name, **kwargs):
        """Stop a cluster.

        Parameters
        ----------
        cluster_name : str
            The cluster name.
        """
        return self.sync(self._stop_cluster, cluster_name, **kwargs)

    async def _scale_cluster(self, cluster_name, n):
        url = "%s/gateway/api/clusters/%s/workers" % (self.address,
                                                      cluster_name)
        req = HTTPRequest(
            url=url,
            method="PUT",
            body=json.dumps({"worker_count": n}),
            headers=HTTPHeaders({"Content-type": "application/json"}),
        )
        await self._fetch(req)

    def scale_cluster(self, cluster_name, n, **kwargs):
        """Scale a cluster to n workers.

        Parameters
        ----------
        cluster_name : str
            The cluster name.
        n : int
            The number of workers to scale to.
        """
        return self.sync(self._scale_cluster, cluster_name, n, **kwargs)
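
A usage sketch of the Gateway client above; the gateway address and proxy
port are placeholders:

gateway = Gateway(address="http://gateway.example.com", proxy_address=8786)
name = gateway.submit()            # queue a new cluster, returns its name
cluster = gateway.connect(name)    # block until the cluster is running
gateway.scale_cluster(name, 2)
print(gateway.list_clusters())
gateway.stop_cluster(name)
gateway.close()
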
Example #9
class Gateway(object):
    """A client for a Dask Gateway Server.

    Parameters
    ----------
    address : str, optional
        The address of the gateway server.
    auth : GatewayAuth, optional
        The authentication method to use.
    asynchronous : bool, optional
        If true, starts the client in asynchronous mode, where it can be used
        in other async code.
    loop : IOLoop, optional
        The IOLoop instance to use. Defaults to the current loop in
        asynchronous mode, otherwise a background loop is started.
    """
    def __init__(self, address=None, auth=None, asynchronous=False, loop=None):
        if address is None:
            address = dask.config.get("gateway.address")
        if address is None:
            raise ValueError(
                "No dask-gateway address provided or found in configuration")
        self.address = address.rstrip("/")
        self._auth = get_auth(auth)
        self._cookie_jar = CookieJar()

        self._asynchronous = asynchronous
        self._loop_runner = LoopRunner(loop=loop, asynchronous=asynchronous)
        self.loop = self._loop_runner.loop

        self._loop_runner.start()
        if self._asynchronous:
            self._started = self._start()
        else:
            self.sync(self._start)

    async def _start(self):
        self._http_client = AsyncHTTPClient()

    def close(self):
        """Close this gateway client"""
        if not self.asynchronous:
            self._loop_runner.stop()

    def __del__(self):
        # __del__ is still called, even if __init__ failed. Only close if init
        # actually succeeded.
        if hasattr(self, "_started"):
            self.close()

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

    def __await__(self):
        return self._started.__await__()

    async def __aenter__(self):
        await self._started
        return self

    async def __aexit__(self, typ, value, traceback):
        pass

    def __repr__(self):
        return "Gateway<%s>" % self.address

    @property
    def asynchronous(self):
        return (self._asynchronous
                or getattr(thread_state, "asynchronous", False)
                or (hasattr(self.loop, "_thread_identity")
                    and self.loop._thread_identity == get_ident()))

    def sync(self, func, *args, **kwargs):
        if kwargs.pop("asynchronous", None) or self.asynchronous:
            callback_timeout = kwargs.pop("callback_timeout", None)
            future = func(*args, **kwargs)
            if callback_timeout is not None:
                future = gen.with_timeout(timedelta(seconds=callback_timeout),
                                          future)
            return future
        else:
            return sync(self.loop, func, *args, **kwargs)

    async def _fetch(self, req, raise_error=True):
        self._cookie_jar.pre_request(req)
        resp = await self._http_client.fetch(req, raise_error=False)
        if resp.code == 401:
            context = self._auth.pre_request(req, resp)
            resp = await self._http_client.fetch(req, raise_error=False)
            self._auth.post_response(req, resp, context)
        self._cookie_jar.post_response(resp)
        if raise_error:
            resp.rethrow()
        return resp

    async def _clusters(self, status=None):
        if status is not None:
            if isinstance(status, (str, ClusterStatus)):
                status = [ClusterStatus._create(status)]
            else:
                status = [ClusterStatus._create(s) for s in status]
            query = "?status=" + ",".join(s.name for s in status)
        else:
            query = ""

        url = "%s/gateway/api/clusters/%s" % (self.address, query)
        req = HTTPRequest(url=url)
        resp = await self._fetch(req)
        return [
            ClusterReport._from_json(self.address, r)
            for r in json.loads(resp.body).values()
        ]

    def list_clusters(self, status=None, **kwargs):
        """List clusters for this user.

        Parameters
        ----------
        status : ClusterStatus, str, or list, optional
            The cluster status (or statuses) to select. Valid options are
            'starting', 'started', 'running', 'stopping', 'stopped', 'failed'.
            By default selects active clusters ('starting', 'started',
            'running').

        Returns
        -------
        clusters : list of ClusterReport
        """
        return self.sync(self._clusters, status=status, **kwargs)

    async def _submit(self):
        url = "%s/gateway/api/clusters/" % self.address
        req = HTTPRequest(
            url=url,
            method="POST",
            body=json.dumps({}),
            headers=HTTPHeaders({"Content-type": "application/json"}),
        )
        resp = await self._fetch(req)
        data = json.loads(resp.body)
        return data["name"]

    def submit(self, **kwargs):
        """Submit a new cluster to be started.

        This returns quickly with a ``cluster_name``, which can later be used
        to connect to the cluster.

        Returns
        -------
        cluster_name : str
            The cluster name.
        """
        return self.sync(self._submit, **kwargs)

    async def _cluster_report(self, cluster_name, wait=False):
        params = "?wait" if wait else ""
        url = "%s/gateway/api/clusters/%s%s" % (self.address, cluster_name,
                                                params)
        req = HTTPRequest(url=url)
        resp = await self._fetch(req)
        return ClusterReport._from_json(self.address, json.loads(resp.body))

    async def _connect(self, cluster_name):
        while True:
            try:
                report = await self._cluster_report(cluster_name, wait=True)
            except HTTPError as exc:
                if exc.code == 404:
                    raise Exception("Unknown cluster %r" % cluster_name)
                elif exc.code == 599:
                    # Timeout, ignore
                    pass
                else:
                    raise
            else:
                if report.status is ClusterStatus.RUNNING:
                    return GatewayCluster(self, report)
                elif report.status is ClusterStatus.FAILED:
                    raise Exception("Cluster %r failed to start, see logs for "
                                    "more information" % cluster_name)
                elif report.status is ClusterStatus.STOPPED:
                    raise Exception("Cluster %r is already stopped" %
                                    cluster_name)
            # Not started yet, try again later
            await gen.sleep(0.5)

    def connect(self, cluster_name, **kwargs):
        """Connect to a submitted cluster.

        Returns
        -------
        cluster : GatewayCluster
        """
        return self.sync(self._connect, cluster_name, **kwargs)

    async def _new_cluster(self, **kwargs):
        cluster_name = await self._submit(**kwargs)
        try:
            return await self._connect(cluster_name)
        except BaseException:
            # Ensure cluster is stopped on error
            await self._stop_cluster(cluster_name)
            raise

    def new_cluster(self, **kwargs):
        """Submit a new cluster to the gateway, and wait for it to be started.

        Same as calling ``submit`` and ``connect`` in one go.

        Returns
        -------
        cluster : GatewayCluster
        """
        return self.sync(self._new_cluster, **kwargs)

    async def _stop_cluster(self, cluster_name):
        url = "%s/gateway/api/clusters/%s" % (self.address, cluster_name)
        req = HTTPRequest(url=url, method="DELETE")
        await self._fetch(req)

    def stop_cluster(self, cluster_name, **kwargs):
        """Stop a cluster.

        Parameters
        ----------
        cluster_name : str
            The cluster name.
        """
        return self.sync(self._stop_cluster, cluster_name, **kwargs)

    async def _scale_cluster(self, cluster_name, n):
        url = "%s/gateway/api/clusters/%s/workers" % (self.address,
                                                      cluster_name)
        req = HTTPRequest(
            url=url,
            method="PUT",
            body=json.dumps({"worker_count": n}),
            headers=HTTPHeaders({"Content-type": "application/json"}),
        )
        try:
            await self._fetch(req)
        except HTTPError as exc:
            if exc.code == 409:
                raise Exception("Cluster %r is not running" % cluster_name)
            raise

    def scale_cluster(self, cluster_name, n, **kwargs):
        """Scale a cluster to n workers.

        Parameters
        ----------
        cluster_name : str
            The cluster name.
        n : int
            The number of workers to scale to.
        """
        return self.sync(self._scale_cluster, cluster_name, n, **kwargs)
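
The same client can be driven asynchronously via the __await__/__aenter__
protocol defined above (a sketch; the address is a placeholder):

import asyncio

async def main():
    async with Gateway(address="http://gateway.example.com",
                       asynchronous=True) as gateway:
        print(await gateway.list_clusters())

asyncio.get_event_loop().run_until_complete(main())
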
Example #10
class Gateway(object):
    """A client for a Dask Gateway Server.

    Parameters
    ----------
    address : str, optional
        The address of the gateway server.
    proxy_address : str or int, optional
        The address of the scheduler proxy server. Defaults to ``address`` if
        not provided. If an int, it is used as the port, with the host/IP
        taken from ``address``. Provide a full address if a different host/IP
        should be used.
    auth : GatewayAuth, optional
        The authentication method to use.
    asynchronous : bool, optional
        If true, starts the client in asynchronous mode, where it can be used
        in other async code.
    loop : IOLoop, optional
        The IOLoop instance to use. Defaults to the current loop in
        asynchronous mode, otherwise a background loop is started.
    """
    def __init__(self,
                 address=None,
                 proxy_address=None,
                 auth=None,
                 asynchronous=False,
                 loop=None):
        if address is None:
            address = format_template(dask.config.get("gateway.address"))
        if address is None:
            raise ValueError(
                "No dask-gateway address provided or found in configuration")
        address = address.rstrip("/")

        public_address = format_template(
            dask.config.get("gateway.public-address"))
        if public_address is None:
            public_address = address
        else:
            public_address = public_address.rstrip("/")

        if proxy_address is None:
            proxy_address = format_template(
                dask.config.get("gateway.proxy-address"))
        if proxy_address is None:
            parsed = urlparse(address)
            if parsed.netloc:
                if parsed.port is None:
                    proxy_port = {
                        "http": 80,
                        "https": 443
                    }.get(parsed.scheme, 8786)
                else:
                    proxy_port = parsed.port
                proxy_netloc = "%s:%d" % (parsed.hostname, proxy_port)
            else:
                # ``address`` has no parseable netloc (e.g. it was given
                # without a scheme), so fall back to the raw string as
                # host:port.
                proxy_netloc = address
        elif isinstance(proxy_address, int):
            parsed = urlparse(address)
            proxy_netloc = "%s:%d" % (parsed.hostname, proxy_address)
        elif isinstance(proxy_address, str):
            parsed = urlparse(proxy_address)
            proxy_netloc = parsed.netloc if parsed.netloc else proxy_address
        proxy_address = "gateway://%s" % proxy_netloc

        scheme = urlparse(address).scheme
        self._request_kwargs = _get_default_request_kwargs(scheme)

        self.address = address
        self._public_address = public_address
        self.proxy_address = proxy_address

        self.auth = get_auth(auth)
        self._session = None

        self._asynchronous = asynchronous
        self._loop_runner = LoopRunner(loop=loop, asynchronous=asynchronous)
        self._loop_runner.start()

    @property
    def loop(self):
        return self._loop_runner.loop

    @property
    def asynchronous(self):
        return self._asynchronous

    def sync(self, func, *args, **kwargs):
        if self.asynchronous:
            return func(*args, **kwargs)
        else:
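            # Not on the loop's thread: hand the coroutine to the background
            # loop and block this thread until it completes (or fails).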
            future = asyncio.run_coroutine_threadsafe(func(*args, **kwargs),
                                                      self.loop.asyncio_loop)
            try:
                return future.result()
            except BaseException:
                future.cancel()
                raise

    def close(self):
        """Close the gateway client"""
        if self.asynchronous:
            return self._cleanup()
        elif self.loop.asyncio_loop.is_running():
            self.sync(self._cleanup)
        self._loop_runner.stop()

    async def _cleanup(self):
        if self._session is not None:
            await self._session.close()
            self._session = None

    async def __aenter__(self):
        return self

    async def __aexit__(self, typ, value, traceback):
        await self._cleanup()

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

    def __del__(self):
        if (not self.asynchronous and hasattr(self, "_loop_runner")
                and not sys.is_finalizing()):
            self.close()

    def __repr__(self):
        return "Gateway<%s>" % self.address

    async def _request(self, method, url, json=None):
        if self._session is None:
            self._session = aiohttp.ClientSession()
        session = self._session

        resp = await session.request(method,
                                     url,
                                     json=json,
                                     **self._request_kwargs)

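        # A 401 means the server wants credentials: let the configured auth
        # method amend the request, then retry it once.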
        if resp.status == 401:
            headers, context = self.auth.pre_request(resp)
            resp = await session.request(method,
                                         url,
                                         json=json,
                                         headers=headers,
                                         **self._request_kwargs)
            self.auth.post_response(resp, context)

        if resp.status >= 400:
            try:
                msg = await resp.json()
                msg = msg["error"]
            except Exception:
                msg = await resp.text()

            if resp.status in {404, 422}:
                raise ValueError(msg)
            elif resp.status == 409:
                raise GatewayClusterError(msg)
            elif resp.status == 500:
                raise GatewayServerError(msg)
            else:
                resp.raise_for_status()
        else:
            return resp

    async def _clusters(self, status=None):
        if status is not None:
            if isinstance(status, (str, ClusterStatus)):
                status = [ClusterStatus._create(status)]
            else:
                status = [ClusterStatus._create(s) for s in status]
            query = "?status=" + ",".join(s.name for s in status)
        else:
            query = ""

        url = "%s/api/v1/clusters/%s" % (self.address, query)
        resp = await self._request("GET", url)
        data = await resp.json()
        return [
            ClusterReport._from_json(self._public_address, self.proxy_address,
                                     r) for r in data.values()
        ]

    def list_clusters(self, status=None, **kwargs):
        """List clusters for this user.

        Parameters
        ----------
        status : ClusterStatus, str, or list, optional
            The cluster status (or statuses) to select. Valid options are
            'pending', 'running', 'stopping', 'stopped', 'failed'.
            By default selects active clusters ('pending', 'running').

        Returns
        -------
        clusters : list of ClusterReport
        """
        return self.sync(self._clusters, status=status, **kwargs)

    def get_cluster(self, cluster_name, **kwargs):
        """Get information about a specific cluster.

        Parameters
        ----------
        cluster_name : str
            The cluster name.

        Returns
        -------
        report : ClusterReport
        """
        return self.sync(self._cluster_report, cluster_name, **kwargs)

    async def _get_versions(self):
        url = "%s/api/version" % self.address
        resp = await self._request("GET", url)
        server_info = await resp.json()
        from . import __version__

        return {
            "server": server_info,
            "client": {
                "version": __version__
            },
        }

    def get_versions(self):
        """Return version info for the server and client

        Returns
        -------
        version_info : dict
        """
        return self.sync(self._get_versions)

    def _config_cluster_options(self):
        opts = dask.config.get("gateway.cluster.options")
        return {k: format_template(v) for k, v in opts.items()}

    async def _cluster_options(self, use_local_defaults=True):
        url = "%s/api/v1/options" % self.address
        resp = await self._request("GET", url)
        data = await resp.json()
        options = Options._from_spec(data["cluster_options"])
        if use_local_defaults:
            options.update(self._config_cluster_options())
        return options

    def cluster_options(self, use_local_defaults=True, **kwargs):
        """Get the available cluster configuration options.

        Parameters
        ----------
        use_local_defaults : bool, optional
            Whether to use any default options from the local configuration.
            Default is True; set to False to use only the server-side defaults.

        Returns
        -------
        cluster_options : dask_gateway.options.Options
            A dict of cluster options.
        """
        return self.sync(self._cluster_options,
                         use_local_defaults=use_local_defaults,
                         **kwargs)

    async def _submit(self, cluster_options=None, **kwargs):
        url = "%s/api/v1/clusters/" % self.address
        if cluster_options is not None:
            if not isinstance(cluster_options, Options):
                raise TypeError(
                    "cluster_options must be an `Options`, got %r" %
                    type(cluster_options).__name__)
            options = dict(cluster_options)
            options.update(kwargs)
        else:
            options = self._config_cluster_options()
            options.update(kwargs)
        resp = await self._request("POST",
                                   url,
                                   json={"cluster_options": options})
        data = await resp.json()
        return data["name"]

    def submit(self, cluster_options=None, **kwargs):
        """Submit a new cluster to be started.

        This returns quickly with a ``cluster_name``, which can later be used
        to connect to the cluster.

        Parameters
        ----------
        cluster_options : dask_gateway.options.Options, optional
            An ``Options`` object describing the desired cluster configuration.
        **kwargs :
            Additional cluster configuration options. If ``cluster_options`` is
            provided, these are applied afterwards as overrides. Available
            options are specific to each deployment of dask-gateway; see
            ``cluster_options`` for more information.

        Returns
        -------
        cluster_name : str
            The cluster name.
        """
        return self.sync(self._submit,
                         cluster_options=cluster_options,
                         **kwargs)

    async def _cluster_report(self, cluster_name, wait=False):
        params = "?wait" if wait else ""
        url = "%s/api/v1/clusters/%s%s" % (self.address, cluster_name, params)
        resp = await self._request("GET", url)
        data = await resp.json()
        return ClusterReport._from_json(self._public_address,
                                        self.proxy_address, data)

    async def _wait_for_start(self, cluster_name):
        while True:
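            # ``wait=True`` asks the server to long-poll until the cluster
            # status changes, so most of the waiting happens server-side.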
            try:
                report = await self._cluster_report(cluster_name, wait=True)
            except TimeoutError:
                # Timeout, ignore
                pass
            else:
                if report.status is ClusterStatus.RUNNING:
                    return report
                elif report.status is ClusterStatus.FAILED:
                    raise GatewayClusterError(
                        "Cluster %r failed to start, see logs for "
                        "more information" % cluster_name)
                elif report.status is ClusterStatus.STOPPED:
                    raise GatewayClusterError("Cluster %r is already stopped" %
                                              cluster_name)
            # Not started yet, try again later
            await asyncio.sleep(0.5)

    def connect(self, cluster_name, shutdown_on_close=False):
        """Connect to a submitted cluster.

        Parameters
        ----------
        cluster_name : str
            The cluster to connect to.
        shutdown_on_close : bool, optional
            If True, the cluster will be automatically shut down on close.
            Default is False.

        Returns
        -------
        cluster : GatewayCluster
        """
        return GatewayCluster.from_name(
            cluster_name,
            shutdown_on_close=shutdown_on_close,
            address=self.address,
            proxy_address=self.proxy_address,
            auth=self.auth,
            asynchronous=self.asynchronous,
            loop=self.loop,
        )

    def new_cluster(self,
                    cluster_options=None,
                    shutdown_on_close=True,
                    **kwargs):
        """Submit a new cluster to the gateway, and wait for it to be started.

        Same as calling ``submit`` and ``connect`` in one go.

        Parameters
        ----------
        cluster_options : dask_gateway.options.Options, optional
            An ``Options`` object describing the desired cluster configuration.
        shutdown_on_close : bool, optional
            If True (default), the cluster will be automatically shut down
            on close. Set to False to have the cluster persist until
            explicitly shut down.
        **kwargs :
            Additional cluster configuration options. If ``cluster_options`` is
            provided, these are applied afterwards as overrides. Available
            options are specific to each deployment of dask-gateway; see
            ``cluster_options`` for more information.

        Returns
        -------
        cluster : GatewayCluster
        """
        return GatewayCluster(
            address=self.address,
            proxy_address=self.proxy_address,
            auth=self.auth,
            asynchronous=self.asynchronous,
            loop=self.loop,
            cluster_options=cluster_options,
            shutdown_on_close=shutdown_on_close,
            **kwargs,
        )

    async def _stop_cluster(self, cluster_name):
        url = "%s/api/v1/clusters/%s" % (self.address, cluster_name)
        await self._request("DELETE", url)

    def stop_cluster(self, cluster_name, **kwargs):
        """Stop a cluster.

        Parameters
        ----------
        cluster_name : str
            The cluster name.
        """
        return self.sync(self._stop_cluster, cluster_name, **kwargs)

    async def _scale_cluster(self, cluster_name, n):
        url = "%s/api/v1/clusters/%s/scale" % (self.address, cluster_name)
        await self._request("POST", url, json={"count": n})

    def scale_cluster(self, cluster_name, n, **kwargs):
        """Scale a cluster to n workers.

        Parameters
        ----------
        cluster_name : str
            The cluster name.
        n : int
            The number of workers to scale to.
        """
        return self.sync(self._scale_cluster, cluster_name, n, **kwargs)

    async def _adapt_cluster(self,
                             cluster_name,
                             minimum=None,
                             maximum=None,
                             active=True):
        await self._request(
            "POST",
            "%s/api/v1/clusters/%s/adapt" % (self.address, cluster_name),
            json={
                "minimum": minimum,
                "maximum": maximum,
                "active": active
            },
        )

    def adapt_cluster(self,
                      cluster_name,
                      minimum=None,
                      maximum=None,
                      active=True,
                      **kwargs):
        """Configure adaptive scaling for a cluster.

        Parameters
        ----------
        cluster_name : str
            The cluster name.
        minimum : int, optional
            The minimum number of workers to scale to. Defaults to 0.
        maximum : int, optional
            The maximum number of workers to scale to. Defaults to infinity.
        active : bool, optional
            If ``True`` (default), adaptive scaling is activated. Set to
            ``False`` to deactivate adaptive scaling.
        """
        return self.sync(
            self._adapt_cluster,
            cluster_name,
            minimum=minimum,
            maximum=maximum,
            active=active,
            **kwargs,
        )
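Tying the ``Gateway`` methods together, a short usage sketch (hedged; not part of the original listing). The server address and the ``worker_cores`` option are placeholders, since the available options are specific to each deployment:

from dask_gateway import Gateway

# Hypothetical deployment address; auth and TLS settings fall back to config.
with Gateway(address="https://gateway.example.com") as gateway:
    options = gateway.cluster_options()         # server-defined option schema
    options.worker_cores = 2                    # placeholder option name
    cluster = gateway.new_cluster(options)      # submit and wait for RUNNING
    gateway.adapt_cluster(cluster.name, minimum=1, maximum=10)
    gateway.stop_cluster(cluster.name)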
Ejemplo n.º 12
0
def test_loop_runner(loop_in_thread):
    # Implicit loop
    loop = IOLoop()
    loop.make_current()
    runner = LoopRunner()
    assert runner.loop not in (loop, loop_in_thread)
    assert not runner.is_started()
    assert_not_running(runner.loop)
    runner.start()
    assert runner.is_started()
    assert_running(runner.loop)
    runner.stop()
    assert not runner.is_started()
    assert_not_running(runner.loop)

    # Explicit loop
    loop = IOLoop()
    runner = LoopRunner(loop=loop)
    assert runner.loop is loop
    assert not runner.is_started()
    assert_not_running(loop)
    runner.start()
    assert runner.is_started()
    assert_running(loop)
    runner.stop()
    assert not runner.is_started()
    assert_not_running(loop)

    # Explicit loop, already started
    runner = LoopRunner(loop=loop_in_thread)
    assert not runner.is_started()
    assert_running(loop_in_thread)
    runner.start()
    assert runner.is_started()
    assert_running(loop_in_thread)
    runner.stop()
    assert not runner.is_started()
    assert_running(loop_in_thread)

    # Implicit loop, asynchronous=True
    loop = IOLoop()
    loop.make_current()
    runner = LoopRunner(asynchronous=True)
    assert runner.loop is loop
    assert not runner.is_started()
    assert_not_running(runner.loop)
    runner.start()
    assert runner.is_started()
    assert_not_running(runner.loop)
    runner.stop()
    assert not runner.is_started()
    assert_not_running(runner.loop)

    # Explicit loop, asynchronous=True
    loop = IOLoop()
    runner = LoopRunner(loop=loop, asynchronous=True)
    assert runner.loop is loop
    assert not runner.is_started()
    assert_not_running(runner.loop)
    runner.start()
    assert runner.is_started()
    assert_not_running(runner.loop)
    runner.stop()
    assert not runner.is_started()
    assert_not_running(runner.loop)
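The loop-runner tests above rely on ``assert_running`` and ``assert_not_running`` helpers that are not part of these listings. A minimal sketch, assuming a tornado ``IOLoop``: a running loop executes a scheduled callback promptly, while a stopped loop either rejects the callback or never runs it:

import queue

import pytest


def assert_running(loop):
    # A running loop should execute a scheduled callback almost immediately.
    q = queue.Queue()
    loop.add_callback(q.put, 42)
    assert q.get(timeout=1) == 42


def assert_not_running(loop):
    # A stopped loop either refuses new callbacks or silently never runs them.
    q = queue.Queue()
    try:
        loop.add_callback(q.put, 42)
    except RuntimeError:
        pass
    else:
        with pytest.raises(queue.Empty):
            q.get(timeout=0.02)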