Example #1
def test_close_loop_sync(with_own_loop):
    loop_runner = loop = None

    # Set up a simple cluster with one threaded worker.
    # A complex setup is not required here since we only test IO loop teardown.
    cluster_params = dict(n_workers=1, dashboard_address=None, processes=False)

    loops_before = LoopRunner._all_loops.copy()

    # Start our own loop, or use the current thread's loop.
    if with_own_loop:
        loop_runner = LoopRunner()
        loop_runner.start()
        loop = loop_runner.loop

    with LocalCluster(loop=loop, **cluster_params) as cluster:
        with Client(cluster, loop=loop) as client:
            client.run(max, 1, 2)

    # Our own loop must be stopped explicitly.
    if with_own_loop:
        loop_runner.stop()

    # The internal loop registry must be the same as before the cluster ran.
    # This means the loop runners in LocalCluster and Client were stopped correctly.
    # See LoopRunner._stop_unlocked().
    assert loops_before == LoopRunner._all_loops
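
The ``with_own_loop`` argument must be supplied from outside the test, presumably by pytest parametrization. A plausible sketch of the missing decorator (an assumption; it is not shown above):

import pytest

# Assumed parametrization: exercise loop teardown both with and without an
# explicitly managed LoopRunner.
@pytest.mark.parametrize("with_own_loop", [True, False])
def test_close_loop_sync(with_own_loop):
    ...  # body as in Example #1 above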
Example #2
async def test_loop_runner_gen():
    runner = LoopRunner(asynchronous=True)
    assert runner.loop is IOLoop.current()
    assert not runner.is_started()
    await asyncio.sleep(0.01)
    runner.start()
    assert runner.is_started()
    await asyncio.sleep(0.01)
    runner.stop()
    assert not runner.is_started()
    await asyncio.sleep(0.01)
Example #3
def test_loop_runner_gen():
    runner = LoopRunner(asynchronous=True)
    assert runner.loop is IOLoop.current()
    assert not runner.is_started()
    yield gen.sleep(0.01)
    runner.start()
    assert runner.is_started()
    yield gen.sleep(0.01)
    runner.stop()
    assert not runner.is_started()
    yield gen.sleep(0.01)
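
Examples #2 and #3 are the asyncio and tornado-gen variants of the same test. A minimal standalone sketch of the behavior they both exercise, assuming distributed's ``LoopRunner`` (import path assumed): in asynchronous mode the runner adopts the caller's loop, and ``start()``/``stop()`` only toggle bookkeeping without spawning a thread.

import asyncio

from tornado.ioloop import IOLoop
from distributed.utils import LoopRunner  # assumed import path

async def main():
    runner = LoopRunner(asynchronous=True)
    assert runner.loop is IOLoop.current()  # adopts the caller's loop
    runner.start()                          # no background thread is spawned
    assert runner.is_started()
    runner.stop()
    assert not runner.is_started()

asyncio.run(main())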
Example #4
class ContextCluster(Cluster):
    def __init__(self, asynchronous=False, loop=None):
        self._loop_runner = LoopRunner(loop=loop, asynchronous=asynchronous)
        self.loop = self._loop_runner.loop

        super().__init__(asynchronous=asynchronous)

        if not self.asynchronous:
            self._loop_runner.start()
            self.sync(self._start)

    def __enter__(self):
        if self.status != "running":
            raise ValueError(
                f"Expected status 'running', found '{self.status}'")
        return self

    def __exit__(self, typ, value, traceback):
        self.close()
        self._loop_runner.stop()
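
``distributed.LocalCluster`` follows the same enter/exit protocol as this ``ContextCluster``, so the pattern can be seen end to end with a runnable snippet (a sketch; requires distributed):

from distributed import LocalCluster

# Entering starts work on the cluster's background loop; exiting closes the
# cluster and releases its loop runner.
with LocalCluster(n_workers=1, processes=False, dashboard_address=None) as cluster:
    print(cluster.status)  # expected to be "running" inside the block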
Example #5
def test_two_loop_runners(loop_in_thread):
    # Loop runners tied to the same loop should cooperate

    # ABCCBA
    loop = IOLoop()
    a = LoopRunner(loop=loop)
    b = LoopRunner(loop=loop)
    assert_not_running(loop)
    a.start()
    assert_running(loop)
    c = LoopRunner(loop=loop)
    b.start()
    assert_running(loop)
    c.start()
    assert_running(loop)
    c.stop()
    assert_running(loop)
    b.stop()
    assert_running(loop)
    a.stop()
    assert_not_running(loop)

    # ABCABC
    loop = IOLoop()
    a = LoopRunner(loop=loop)
    b = LoopRunner(loop=loop)
    assert_not_running(loop)
    a.start()
    assert_running(loop)
    b.start()
    assert_running(loop)
    c = LoopRunner(loop=loop)
    c.start()
    assert_running(loop)
    a.stop()
    assert_running(loop)
    b.stop()
    assert_running(loop)
    c.stop()
    assert_not_running(loop)

    # Explicit loop, already started
    a = LoopRunner(loop=loop_in_thread)
    b = LoopRunner(loop=loop_in_thread)
    assert_running(loop_in_thread)
    a.start()
    assert_running(loop_in_thread)
    b.start()
    assert_running(loop_in_thread)
    a.stop()
    assert_running(loop_in_thread)
    b.stop()
    assert_running(loop_in_thread)
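
The cooperation being tested is reference counting: every ``start()`` on a shared loop takes a reference, and the loop only shuts down when the last runner releases its reference. A condensed sketch (assumes distributed's ``LoopRunner``):

from tornado.ioloop import IOLoop
from distributed.utils import LoopRunner  # assumed import path

loop = IOLoop()
a = LoopRunner(loop=loop)
b = LoopRunner(loop=loop)
a.start()
b.start()
a.stop()  # the loop keeps running: b still holds a reference
assert not a.is_started() and b.is_started()
b.stop()  # the last stop actually shuts the loop down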
Example #6
def test_loop_runner(loop_in_thread):
    # Implicit loop
    loop = IOLoop()
    loop.make_current()
    runner = LoopRunner()
    assert runner.loop not in (loop, loop_in_thread)
    assert not runner.is_started()
    assert_not_running(runner.loop)
    runner.start()
    assert runner.is_started()
    assert_running(runner.loop)
    runner.stop()
    assert not runner.is_started()
    assert_not_running(runner.loop)

    # Explicit loop
    loop = IOLoop()
    runner = LoopRunner(loop=loop)
    assert runner.loop is loop
    assert not runner.is_started()
    assert_not_running(loop)
    runner.start()
    assert runner.is_started()
    assert_running(loop)
    runner.stop()
    assert not runner.is_started()
    assert_not_running(loop)

    # Explicit loop, already started
    runner = LoopRunner(loop=loop_in_thread)
    assert not runner.is_started()
    assert_running(loop_in_thread)
    runner.start()
    assert runner.is_started()
    assert_running(loop_in_thread)
    runner.stop()
    assert not runner.is_started()
    assert_running(loop_in_thread)

    # Implicit loop, asynchronous=True
    loop = IOLoop()
    loop.make_current()
    runner = LoopRunner(asynchronous=True)
    assert runner.loop is loop
    assert not runner.is_started()
    assert_not_running(runner.loop)
    runner.start()
    assert runner.is_started()
    assert_not_running(runner.loop)
    runner.stop()
    assert not runner.is_started()
    assert_not_running(runner.loop)

    # Explicit loop, asynchronous=True
    loop = IOLoop()
    runner = LoopRunner(loop=loop, asynchronous=True)
    assert runner.loop is loop
    assert not runner.is_started()
    assert_not_running(runner.loop)
    runner.start()
    assert runner.is_started()
    assert_not_running(runner.loop)
    runner.stop()
    assert not runner.is_started()
    assert_not_running(runner.loop)
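
The ``assert_running`` / ``assert_not_running`` helpers are not shown in these examples. A plausible implementation (my sketch, not necessarily the test suite's exact helpers) pings the loop with a callback and waits for the answer:

import queue

import pytest

def assert_running(loop):
    """Raise if the given IOLoop is not being run in some thread."""
    q = queue.Queue()
    loop.add_callback(q.put, 42)
    assert q.get(timeout=2) == 42

def assert_not_running(loop):
    """Raise if the given IOLoop is being run."""
    q = queue.Queue()
    try:
        loop.add_callback(q.put, 42)
    except RuntimeError:
        # An asyncio-backed loop may refuse callbacks once closed.
        pass
    else:
        with pytest.raises(queue.Empty):
            q.get(timeout=0.02)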
Example #7
class YarnCluster(object):
    """Start a Dask cluster on YARN.

    You can define default values for this in Dask's ``yarn.yaml``
    configuration file. See http://docs.dask.org/en/latest/configuration.html
    for more information.

    Parameters
    ----------
    environment : str, optional
        The Python environment to use. Can be one of the following:

          - A path to an archived Python environment
          - A path to a conda environment, specified as `conda:///...`
          - A path to a virtual environment, specified as `venv:///...`
          - A path to a python executable, specified as `python:///...`

        Note that if not an archive, the paths specified must be valid on all
        nodes in the cluster.
    n_workers : int, optional
        The number of workers to initially start.
    worker_vcores : int, optional
        The number of virtual cores to allocate per worker.
    worker_memory : str, optional
        The amount of memory to allocate per worker. Accepts a unit suffix
        (e.g. '2 GiB' or '4096 MiB'). Will be rounded up to the nearest MiB.
    worker_restarts : int, optional
        The maximum number of worker restarts to allow before failing the
        application. Default is unlimited.
    worker_env : dict, optional
        A mapping of environment variables to their values. These will be set
        in the worker containers before starting the dask workers.
    scheduler_vcores : int, optional
        The number of virtual cores to allocate per scheduler.
    scheduler_memory : str, optional
        The amount of memory to allocate to the scheduler. Accepts a unit
        suffix (e.g. '2 GiB' or '4096 MiB'). Will be rounded up to the nearest
        MiB.
    deploy_mode : {'remote', 'local'}, optional
        The deploy mode to use. If ``'remote'``, the scheduler will be deployed
        in a YARN container. If ``'local'``, the scheduler will run locally,
        which can be nice for debugging. Default is ``'remote'``.
    name : str, optional
        The application name.
    queue : str, optional
        The queue to deploy to.
    tags : sequence, optional
        A set of strings to use as tags for this application.
    user : str, optional
        The user to submit the application on behalf of. Default is the current
        user; submitting as a different user requires user permissions, see
        the YARN documentation for more information.
    host : str, optional
        Host address on which the scheduler will listen. Only used if
        ``deploy_mode='local'``. Defaults to ``'0.0.0.0'``.
    port : int, optional
        The port on which the scheduler will listen. Only used if
        ``deploy_mode='local'``. Defaults to ``0`` for a random port.
    dashboard_address : str, optional
        Address on which the dashboard server will listen. Only used if
        ``deploy_mode='local'``. Defaults to ``':0'`` for a random port.
    skein_client : skein.Client, optional
        The ``skein.Client`` to use. If not provided, one will be started.
    asynchronous : bool, optional
        If true, starts the cluster in asynchronous mode, where it can be used
        in other async code.
    loop : IOLoop, optional
        The IOLoop instance to use. Defaults to the current loop in
        asynchronous mode, otherwise a background loop is started.

    Examples
    --------
    >>> cluster = YarnCluster(environment='my-env.tar.gz', ...)
    >>> cluster.scale(10)
    """
    def __init__(
        self,
        environment=None,
        n_workers=None,
        worker_vcores=None,
        worker_memory=None,
        worker_restarts=None,
        worker_env=None,
        scheduler_vcores=None,
        scheduler_memory=None,
        deploy_mode=None,
        name=None,
        queue=None,
        tags=None,
        user=None,
        host=None,
        port=None,
        dashboard_address=None,
        skein_client=None,
        asynchronous=False,
        loop=None,
    ):
        spec = _make_specification(
            environment=environment,
            n_workers=n_workers,
            worker_vcores=worker_vcores,
            worker_memory=worker_memory,
            worker_restarts=worker_restarts,
            worker_env=worker_env,
            scheduler_vcores=scheduler_vcores,
            scheduler_memory=scheduler_memory,
            deploy_mode=deploy_mode,
            name=name,
            queue=queue,
            tags=tags,
            user=user,
        )
        self._init_common(
            spec=spec,
            host=host,
            port=port,
            dashboard_address=dashboard_address,
            asynchronous=asynchronous,
            loop=loop,
            skein_client=skein_client,
        )

    @classmethod
    def from_specification(cls,
                           spec,
                           skein_client=None,
                           asynchronous=False,
                           loop=None):
        """Start a dask cluster from a skein specification.

        Parameters
        ----------
        spec : skein.ApplicationSpec, dict, or filename
            The application specification to use. Must define at least one
            service: ``'dask.worker'``. If no ``'dask.scheduler'`` service is
            defined, a scheduler will be started locally.
        skein_client : skein.Client, optional
            The ``skein.Client`` to use. If not provided, one will be started.
        asynchronous : bool, optional
            If true, starts the cluster in asynchronous mode, where it can be
            used in other async code.
        loop : IOLoop, optional
            The IOLoop instance to use. Defaults to the current loop in
            asynchronous mode, otherwise a background loop is started.
        """
        self = super(YarnCluster, cls).__new__(cls)
        if isinstance(spec, dict):
            spec = skein.ApplicationSpec.from_dict(spec)
        elif isinstance(spec, str):
            spec = skein.ApplicationSpec.from_file(spec)
        elif not isinstance(spec, skein.ApplicationSpec):
            raise TypeError("spec must be an ApplicationSpec, dict, or path, "
                            "got %r" % type(spec).__name__)
        if "dask.worker" not in spec.services:
            raise ValueError(
                "Provided Skein specification must include a 'dask.worker' service"
            )

        self._init_common(spec=spec,
                          asynchronous=asynchronous,
                          loop=loop,
                          skein_client=skein_client)
        return self

    @classmethod
    def from_current(cls, asynchronous=False, loop=None):
        """Connect to an existing ``YarnCluster`` from inside the cluster.

        Parameters
        ----------
        asynchronous : bool, optional
            If true, starts the cluster in asynchronous mode, where it can be
            used in other async code.
        loop : IOLoop, optional
            The IOLoop instance to use. Defaults to the current loop in
            asynchronous mode, otherwise a background loop is started.

        Returns
        -------
        YarnCluster
        """
        self = super(YarnCluster, cls).__new__(cls)
        app_id = os.environ.get("DASK_APPLICATION_ID", None)
        app_address = os.environ.get("DASK_APPMASTER_ADDRESS", None)
        security_dir = os.environ.get("DASK_SECURITY_CREDENTIALS", None)
        if app_id is not None and app_address is not None:
            security = (None if security_dir is None else
                        skein.Security.from_directory(security_dir))
            app = skein.ApplicationClient(app_address,
                                          app_id,
                                          security=security)
        else:
            app = skein.ApplicationClient.from_current()

        self._init_common(application_client=app,
                          asynchronous=asynchronous,
                          loop=loop)
        return self

    @classmethod
    def from_application_id(cls,
                            app_id,
                            skein_client=None,
                            asynchronous=False,
                            loop=None):
        """Connect to an existing ``YarnCluster`` with a given application id.

        Parameters
        ----------
        app_id : str
            The existing cluster's application id.
        skein_client : skein.Client
            The ``skein.Client`` to use. If not provided, one will be started.
        asynchronous : bool, optional
            If true, starts the cluster in asynchronous mode, where it can be
            used in other async code.
        loop : IOLoop, optional
            The IOLoop instance to use. Defaults to the current loop in
            asynchronous mode, otherwise a background loop is started.

        Returns
        -------
        YarnCluster
        """
        self = super(YarnCluster, cls).__new__(cls)
        skein_client = _get_skein_client(skein_client)
        app = skein_client.connect(app_id)

        self._init_common(
            application_client=app,
            asynchronous=asynchronous,
            loop=loop,
            skein_client=skein_client,
        )
        return self

    def _init_common(
        self,
        spec=None,
        application_client=None,
        host=None,
        port=None,
        dashboard_address=None,
        asynchronous=False,
        loop=None,
        skein_client=None,
    ):
        self.spec = spec
        self.application_client = application_client
        self._scheduler_kwargs = _make_scheduler_kwargs(
            host=host,
            port=port,
            dashboard_address=dashboard_address,
        )
        self._scheduler = None
        self.scheduler_info = {}
        self._requested = set()
        self.scheduler_comm = None
        self._watch_worker_status_task = None
        self._start_task = None
        self._stop_task = None
        self._finalizer = None
        self._adaptive = None
        self._adaptive_options = {}
        self._skein_client = skein_client
        self._asynchronous = asynchronous
        self._loop_runner = LoopRunner(loop=loop, asynchronous=asynchronous)
        self._loop_runner.start()

        if not self.asynchronous:
            self._sync(self._start_internal())

    def _start_cluster(self):
        """Start the cluster and initialize state"""

        skein_client = _get_skein_client(self._skein_client)

        if "dask.scheduler" not in self.spec.services:
            # deploy_mode == 'local'
            scheduler_address = self._scheduler.address
            for k in ["dashboard", "bokeh"]:
                if k in self._scheduler.services:
                    dashboard_port = self._scheduler.services[k].port
                    dashboard_host = urlparse(scheduler_address).hostname
                    dashboard_address = "http://%s:%d" % (
                        dashboard_host,
                        dashboard_port,
                    )
                    break
            else:
                dashboard_address = None

            with submit_and_handle_failures(skein_client, self.spec) as app:
                app.kv["dask.scheduler"] = scheduler_address.encode()
                if dashboard_address is not None:
                    app.kv["dask.dashboard"] = dashboard_address.encode()
        else:
            # deploy_mode == 'remote'
            with submit_and_handle_failures(skein_client, self.spec) as app:
                scheduler_address = app.kv.wait("dask.scheduler").decode()
                dashboard_address = app.kv.get("dask.dashboard")
                if dashboard_address is not None:
                    dashboard_address = dashboard_address.decode()

        # Ensure application gets cleaned up
        self._finalizer = weakref.finalize(self, app.shutdown)

        self.scheduler_address = scheduler_address
        self._dashboard_address = dashboard_address
        self.application_client = app

    def _connect_existing(self):
        spec = self.application_client.get_specification()
        if "dask.worker" not in spec.services:
            raise ValueError("%r is not a valid dask cluster" % self.app_id)

        scheduler_address = self.application_client.kv.wait(
            "dask.scheduler").decode()
        dashboard_address = self.application_client.kv.get("dask.dashboard")
        if dashboard_address is not None:
            dashboard_address = dashboard_address.decode()

        self.spec = spec
        self.scheduler_address = scheduler_address
        self._dashboard_address = dashboard_address

    async def _start_internal(self):
        if self._start_task is None:
            self._start_task = asyncio.ensure_future(self._start_async())
        try:
            await self._start_task
        except BaseException:
            # On exception, cleanup
            await self._stop_internal()
            raise
        return self

    async def _start_async(self):
        if self.spec is not None:
            # Start a new cluster
            if "dask.scheduler" not in self.spec.services:
                self._scheduler = Scheduler(
                    loop=self.loop,
                    **self._scheduler_kwargs,
                )
                await self._scheduler
            else:
                self._scheduler = None
            await self.loop.run_in_executor(None, self._start_cluster)
        else:
            # Connect to an existing cluster
            await self.loop.run_in_executor(None, self._connect_existing)

        self.scheduler_comm = rpc(self.scheduler_address)
        comm = None
        try:
            comm = await self.scheduler_comm.live_comm()
            await comm.write({"op": "subscribe_worker_status"})
            self.scheduler_info = await comm.read()
            workers = self.scheduler_info.get("workers", {})
            self._requested.update(w["name"] for w in workers.values())
            self._watch_worker_status_task = asyncio.ensure_future(
                self._watch_worker_status(comm))
        except Exception:
            if comm is not None:
                await comm.close()

    async def _stop_internal(self, status="SUCCEEDED", diagnostics=None):
        if self._stop_task is None:
            self._stop_task = asyncio.ensure_future(
                self._stop_async(status=status, diagnostics=diagnostics))
        await self._stop_task

    async def _stop_async(self, status="SUCCEEDED", diagnostics=None):
        if self._start_task is not None:
            if not self._start_task.done():
                # We're still starting, cancel task
                await cancel_task(self._start_task)
            self._start_task = None

        if self._adaptive is not None:
            self._adaptive.stop()

        if self._watch_worker_status_task is not None:
            await cancel_task(self._watch_worker_status_task)
            self._watch_worker_status_task = None

        if self.scheduler_comm is not None:
            self.scheduler_comm.close_rpc()
            self.scheduler_comm = None

        await self.loop.run_in_executor(
            None,
            lambda: self._stop_sync(status=status, diagnostics=diagnostics))

        if self._scheduler is not None:
            await self._scheduler.close()
            self._scheduler = None

    def _stop_sync(self, status="SUCCEEDED", diagnostics=None):
        if self._finalizer is not None and self._finalizer.peek() is not None:
            self.application_client.shutdown(status=status,
                                             diagnostics=diagnostics)
            self._finalizer.detach()  # don't run the finalizer later
        self._finalizer = None

    def __await__(self):
        return self.__aenter__().__await__()

    async def __aenter__(self):
        return await self._start_internal()

    async def __aexit__(self, typ, value, traceback):
        await self._stop_internal()

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

    def __repr__(self):
        return "YarnCluster<%s>" % self.app_id

    @property
    def loop(self):
        return self._loop_runner.loop

    @property
    def app_id(self):
        return self.application_client.id

    @property
    def asynchronous(self):
        return self._asynchronous

    def _sync(self, task):
        if self.asynchronous:
            return task
        future = asyncio.run_coroutine_threadsafe(task, self.loop.asyncio_loop)
        try:
            return future.result()
        except BaseException:
            future.cancel()
            raise

    @cached_property
    def dashboard_link(self):
        """Link to the dask dashboard. None if dashboard isn't running"""
        if self._dashboard_address is None:
            return None
        dashboard = urlparse(self._dashboard_address)
        return format_dashboard_link(dashboard.hostname, dashboard.port)

    def shutdown(self, status="SUCCEEDED", diagnostics=None):
        """Shutdown the application.

        Parameters
        ----------
        status : {'SUCCEEDED', 'FAILED', 'KILLED'}, optional
            The yarn application exit status.
        diagnostics : str, optional
            The application exit message, usually used for diagnosing failures.
            Can be seen in the YARN Web UI for completed applications under
            "diagnostics". If not provided, a default will be used.
        """
        if self.asynchronous:
            return self._stop_internal(status=status, diagnostics=diagnostics)
        if self.loop.asyncio_loop.is_running() and not sys.is_finalizing():
            self._sync(
                self._stop_internal(status=status, diagnostics=diagnostics))
        else:
            # Always run this!
            self._stop_sync(status=status, diagnostics=diagnostics)
        self._loop_runner.stop()

    def close(self, **kwargs):
        """Close this cluster. An alias for ``shutdown``.

        See Also
        --------
        shutdown
        """
        return self.shutdown(**kwargs)

    def __del__(self):
        if not hasattr(self, "_loop_runner"):
            return
        if self.asynchronous:
            # No del for async mode
            return
        self.close()

    @property
    def _observed(self):
        return {w["name"] for w in self.scheduler_info["workers"].values()}

    async def _workers(self):
        return await self.loop.run_in_executor(
            None,
            lambda: self.application_client.get_containers(
                services=["dask.worker"]),
        )

    def workers(self):
        """A list of all currently running worker containers."""
        return self._sync(self._workers())

    async def _logs(self, scheduler=True, workers=True):
        logs = Logs()

        if scheduler:
            slogs = await self.scheduler_comm.logs()
            logs["Scheduler"] = Log("\n".join(line for level, line in slogs))

        if workers:
            d = await self.scheduler_comm.worker_logs(workers=workers)
            for k, v in d.items():
                logs[k] = Log("\n".join(line for level, line in v))

        return logs

    def logs(self, scheduler=True, workers=True):
        """Return logs for the scheduler and/or workers

        Parameters
        ----------
        scheduler : boolean, optional
            Whether or not to collect logs for the scheduler
        workers : boolean or iterable, optional
            A list of worker addresses to select. Defaults to all workers if
            ``True`` or no workers if ``False``

        Returns
        -------
        logs : dict
            A dictionary of name -> logs.
        """
        return self._sync(self._logs(scheduler=scheduler, workers=workers))

    async def _scale_up(self, n):
        if n > len(self._requested):
            containers = await self.loop.run_in_executor(
                None,
                lambda: self.application_client.scale("dask.worker", n),
            )
            self._requested.update(c.id for c in containers)

    async def _scale_down(self, workers):
        self._requested.difference_update(workers)
        await self.scheduler_comm.retire_workers(names=list(workers))

        def _kill_containers():
            for c in workers:
                try:
                    self.application_client.kill_container(c)
                except ValueError:
                    pass

        await self.loop.run_in_executor(
            None,
            _kill_containers,
        )

    async def _scale(self, n):
        if self._adaptive is not None:
            self._adaptive.stop()
        if n >= len(self._requested):
            return await self._scale_up(n)
        else:
            n_to_delete = len(self._requested) - n
            pending = list(self._requested - self._observed)
            running = list(self._observed.difference(pending))
            to_close = pending[:n_to_delete]
            to_close.extend(running[:n_to_delete - len(to_close)])
            await self._scale_down(to_close)
            self._requested.difference_update(to_close)

    def scale(self, n):
        """Scale cluster to n workers.

        Parameters
        ----------
        n : int
            Target number of workers

        Examples
        --------
        >>> cluster.scale(10)  # scale cluster to ten workers
        """
        return self._sync(self._scale(n))

    def adapt(self,
              minimum=0,
              maximum=math.inf,
              interval="1s",
              wait_count=3,
              target_duration="5s",
              **kwargs):
        """Turn on adaptivity

        This scales Dask clusters automatically based on scheduler activity.

        Parameters
        ----------
        minimum : int, optional
            Minimum number of workers. Defaults to ``0``.
        maximum : int, optional
            Maximum number of workers. Defaults to ``inf``.
        interval : timedelta or str, optional
            Time between worker add/remove recommendations.
        wait_count : int, optional
            Number of consecutive times that a worker should be suggested for
            removal before we remove it.
        target_duration : timedelta or str, optional
            Amount of time we want a computation to take. This affects how
            aggressively we scale up.
        **kwargs :
            Additional parameters to pass to
            ``distributed.Scheduler.workers_to_close``.

        Examples
        --------
        >>> cluster.adapt(minimum=0, maximum=10)
        """
        if self._adaptive is not None:
            self._adaptive.stop()
        self._adaptive_options.update(
            minimum=minimum,
            maximum=maximum,
            interval=interval,
            wait_count=wait_count,
            target_duration=target_duration,
            **kwargs,
        )
        self._adaptive = Adaptive(self, **self._adaptive_options)

    async def _watch_worker_status(self, comm):
        # We don't want to hold on to a ref to self, otherwise this will
        # leave a dangling reference and prevent garbage collection.
        ref_self = weakref.ref(self)
        self = None
        try:
            while True:
                try:
                    msgs = await comm.read()
                except OSError:
                    break
                try:
                    self = ref_self()
                    if self is None:
                        break
                    for op, msg in msgs:
                        if op == "add":
                            workers = msg.pop("workers")
                            self.scheduler_info["workers"].update(workers)
                            self.scheduler_info.update(msg)
                            self._requested.update(w["name"]
                                                   for w in workers.values())
                        elif op == "remove":
                            del self.scheduler_info["workers"][msg]
                            self._requested.discard(msg)
                    if hasattr(self, "_status_widget"):
                        self._status_widget.value = self._widget_status()
                finally:
                    self = None
        finally:
            await comm.close()

    def _widget_status(self):
        try:
            workers = self.scheduler_info["workers"]
        except KeyError:
            return None
        else:
            n_workers = len(workers)
            cores = sum(w["nthreads"] for w in workers.values())
            memory = sum(w["memory_limit"] for w in workers.values())

            return _widget_status_template % (n_workers, cores,
                                              format_bytes(memory))

    def _widget(self):
        """ Create IPython widget for display within a notebook """
        try:
            return self._cached_widget
        except AttributeError:
            pass

        if self.asynchronous:
            return None

        try:
            from ipywidgets import Layout, VBox, HBox, IntText, Button, HTML, Accordion
        except ImportError:
            self._cached_widget = None
            return None

        layout = Layout(width="150px")

        title = HTML("<h2>YarnCluster</h2>")

        status = HTML(self._widget_status(), layout=Layout(min_width="150px"))

        request = IntText(0, description="Workers", layout=layout)
        scale = Button(description="Scale", layout=layout)

        minimum = IntText(0, description="Minimum", layout=layout)
        maximum = IntText(0, description="Maximum", layout=layout)
        adapt = Button(description="Adapt", layout=layout)

        accordion = Accordion(
            [HBox([request, scale]),
             HBox([minimum, maximum, adapt])],
            layout=Layout(min_width="500px"),
        )
        accordion.selected_index = None
        accordion.set_title(0, "Manual Scaling")
        accordion.set_title(1, "Adaptive Scaling")

        @adapt.on_click
        def adapt_cb(b):
            self.adapt(minimum=minimum.value, maximum=maximum.value)

        @scale.on_click
        def scale_cb(b):
            with log_errors():
                self.scale(request.value)

        app_id = HTML("<p><b>Application ID: </b>{0}</p>".format(self.app_id))

        elements = [title, HBox([status, accordion]), app_id]

        if self.dashboard_link is not None:
            link = HTML(
                '<p><b>Dashboard: </b><a href="{0}" target="_blank">{0}'
                "</a></p>\n".format(self.dashboard_link))
            elements.append(link)

        self._cached_widget = box = VBox(elements)
        self._status_widget = status

        return box

    def _ipython_display_(self, **kwargs):
        widget = self._widget()
        if widget is not None:
            return widget._ipython_display_(**kwargs)
        else:
            from IPython.display import display

            data = {"text/plain": repr(self), "text/html": self._repr_html_()}
            display(data, raw=True)

    def _repr_html_(self):
        if self.dashboard_link is not None:
            dashboard = "<a href='{0}' target='_blank'>{0}</a>".format(
                self.dashboard_link)
        else:
            dashboard = "Not Available"
        return (
            "<div style='background-color: #f2f2f2; display: inline-block; "
            "padding: 10px; border: 1px solid #999999;'>\n"
            "  <h3>YarnCluster</h3>\n"
            "  <ul>\n"
            "    <li><b>Application ID: </b>{app_id}\n"
            "    <li><b>Dashboard: </b>{dashboard}\n"
            "  </ul>\n"
            "</div>\n").format(app_id=self.app_id, dashboard=dashboard)
Example #8
class AzureMLCluster(Cluster):
    """ Deploy a Dask cluster using Azure ML

    This creates a dask scheduler and workers on an Azure ML Compute Target.

    Parameters
    ----------
    workspace: azureml.core.Workspace (required)
        Azure ML Workspace - see https://aka.ms/azureml/workspace.

    vm_size: str (optional)
        Azure VM size to be used in the Compute Target - see https://aka.ms/azureml/vmsizes.

    datastores: List[Datastore] (optional)
        List of Azure ML Datastores to be mounted on the headnode -
        see https://aka.ms/azureml/data and https://aka.ms/azureml/datastores.

        Defaults to ``[]``. To mount all datastores in the workspace,
        set to ``ws.datastores.values()``.

    environment_definition: azureml.core.Environment (optional)
        Azure ML Environment - see https://aka.ms/azureml/environments.

        Defaults to the "AzureML-Dask-CPU" or "AzureML-Dask-GPU" curated environment.

    scheduler_idle_timeout: int (optional)
        Number of idle seconds leading to scheduler shut down.

        Defaults to ``1200`` (20 minutes).

    experiment_name: str (optional)
        The name of the Azure ML Experiment used to control the cluster.

        Defaults to ``dask-cloudprovider``.

    initial_node_count: int (optional)
        The initial number of nodes for the Dask Cluster.

        Defaults to ``1``.

    jupyter: bool (optional)
        Flag to start JupyterLab session on the headnode of the cluster.

        Defaults to ``False``.

    jupyter_port: int (optional)
        Port on headnode to use for hosting JupyterLab session.

        Defaults to ``9000``.

    dashboard_port: int (optional)
        Port on headnode to use for hosting Dask dashboard.

        Defaults to ``9001``.

    scheduler_port: int (optional)
        Port to map the scheduler port to via the SSH tunnel if the machine
        is not on the same VNET.

        Defaults to ``9002``.

    worker_death_timeout: int (optional)
        Number of seconds to wait for a worker to respond before removing it.

        Defaults to ``30``.

    additional_ports: list[tuple[int, int]] (optional)
        Additional ports to forward. This requires a list of tuples where the first element
        is the port to open on the headnode while the second element is the port to map to
        or forward via the SSH-tunnel.

        Defaults to ``[]``.

    compute_target: azureml.core.ComputeTarget (optional)
        Azure ML Compute Target - see https://aka.ms/azureml/computetarget.

    admin_username: str (optional)
        Username of the admin account for the AzureML Compute.
        Required for runs that are not on the same VNET; an Exception is
        thrown if the machine is not on the same VNET and no username is given.

        Defaults to ``""``.

    admin_ssh_key: str (optional)
        Location of the SSH secret key used when creating the AzureML Compute.
        The key should be passwordless if run from a Jupyter notebook.
        The ``id_rsa`` file needs to have 0700 permissions set.
        Required for runs that are not on the same VNET; an Exception is
        thrown if the machine is not on the same VNET and no key is given.

        Defaults to ``""``.

    vnet: str (optional)
        Name of the virtual network.

    subnet: str (optional)
        Name of the subnet inside the virtual network ``vnet``.

    vnet_resource_group: str (optional)
        Name of the resource group where the virtual network ``vnet``
        is located. If not passed, but names for ``vnet`` and ``subnet`` are
        passed, ``vnet_resource_group`` is assigned with the name of resource
        group associated with ``workspace``

    telemetry_opt_out: bool (optional)
        A boolean parameter. By default a record of the AzureMLCluster version
        in use is logged with Microsoft. Set this flag to ``True`` if you do
        not want to share this information with Microsoft. Microsoft is not
        tracking anything else you do in your Dask cluster nor any other
        information related to your workload.

    asynchronous: bool (optional)
        Flag to run jobs asynchronously.

    **kwargs: dict
        Additional keyword arguments.
    """

    def __init__(
        self,
        workspace,
        compute_target=None,
        environment_definition=None,
        experiment_name=None,
        initial_node_count=None,
        jupyter=None,
        jupyter_port=None,
        dashboard_port=None,
        scheduler_port=None,
        scheduler_idle_timeout=None,
        worker_death_timeout=None,
        additional_ports=None,
        admin_username=None,
        admin_ssh_key=None,
        datastores=None,
        code_store=None,
        vnet_resource_group=None,
        vnet=None,
        subnet=None,
        show_output=False,
        telemetry_opt_out=None,
        asynchronous=False,
        **kwargs,
    ):
        ### REQUIRED PARAMETERS
        self.workspace = workspace
        self.compute_target = compute_target

        ### ENVIRONMENT
        self.environment_definition = environment_definition

        ### EXPERIMENT DEFINITION
        self.experiment_name = experiment_name
        self.tags = {"tag": "azureml-dask"}

        ### ENVIRONMENT AND VARIABLES
        self.initial_node_count = initial_node_count

        ### SEND TELEMETRY
        self.telemetry_opt_out = telemetry_opt_out
        self.telemetry_set = False

        ### FUTURE EXTENSIONS
        self.kwargs = kwargs
        self.show_output = show_output

        ## CREATE COMPUTE TARGET
        self.admin_username = admin_username
        self.admin_ssh_key = admin_ssh_key
        self.vnet_resource_group = vnet_resource_group
        self.vnet = vnet
        self.subnet = subnet
        self.compute_target_set = True
        self.pub_key_file = ""
        self.pri_key_file = ""
        if self.compute_target is None:
            try:
                self.compute_target = self.__create_compute_target()
                self.compute_target_set = False
            except Exception as e:
                logger.exception(e)
                return
        elif self.compute_target.admin_user_ssh_key is not None and (
            self.admin_ssh_key is None or self.admin_username is None
        ):
            logger.exception(
                "Please provide private key and admin username to access compute target {}".format(
                    self.compute_target.name
                )
            )
            return

        ### GPU RUN INFO
        self.workspace_vm_sizes = AmlCompute.supported_vmsizes(self.workspace)
        self.workspace_vm_sizes = [
            (e["name"].lower(), e["gpus"]) for e in self.workspace_vm_sizes
        ]
        self.workspace_vm_sizes = dict(self.workspace_vm_sizes)

        self.compute_target_vm_size = self.compute_target.serialize()["properties"][
            "status"
        ]["vmSize"].lower()
        self.n_gpus_per_node = self.workspace_vm_sizes[self.compute_target_vm_size]
        self.use_gpu = self.n_gpus_per_node > 0
        if self.environment_definition is None:
            if self.use_gpu:
                self.environment_definition = self.workspace.environments[
                    "AzureML-Dask-GPU"
                ]
            else:
                self.environment_definition = self.workspace.environments[
                    "AzureML-Dask-CPU"
                ]

        ### JUPYTER AND PORT FORWARDING
        self.jupyter = jupyter
        self.jupyter_port = jupyter_port
        self.dashboard_port = dashboard_port
        self.scheduler_port = scheduler_port
        self.scheduler_idle_timeout = scheduler_idle_timeout
        self.portforward_proc = None
        self.worker_death_timeout = worker_death_timeout
        self.end_logging = False  # FLAG FOR STOPPING THE port_forward_logger THREAD

        if additional_ports is not None:
            if type(additional_ports) != list:
                error_message = (
                    f"The additional_ports parameter is of {type(additional_ports)}"
                    " type but needs to be a list of int tuples."
                    " Check the documentation."
                )
                logger.exception(error_message)
                raise TypeError(error_message)

            if len(additional_ports) > 0:
                if type(additional_ports[0]) != tuple:
                    error_message = (
                        f"The additional_ports elements are of {type(additional_ports[0])}"
                        " type but needs to be a list of int tuples."
                        " Check the documentation."
                    )
                    raise TypeError(error_message)

                ### check if all elements are tuples of length two and int type
                all_correct = True
                for el in additional_ports:
                    if type(el) != tuple or len(el) != 2:
                        all_correct = False
                        break

                    if (type(el[0]), type(el[1])) != (int, int):
                        all_correct = False
                        break

                if not all_correct:
                    error_message = (
                        "At least one of the elements of the additional_ports parameter"
                        " is wrong. Make sure it is a list of int tuples."
                        " Check the documentation."
                    )
                    raise TypeError(error_message)

        self.additional_ports = additional_ports
        self.scheduler_ip_port = None  ### INIT FOR HOLDING THE ADDRESS FOR THE SCHEDULER

        ### DATASTORES
        self.datastores = datastores

        ### RUNNING IN MATRIX OR LOCAL
        self.same_vnet = None
        self.is_in_ci = False

        ### GET RUNNING LOOP
        self._loop_runner = LoopRunner(loop=None, asynchronous=asynchronous)
        self.loop = self._loop_runner.loop

        self.abs_path = pathlib.Path(__file__).parent.absolute()

        ### INITIALIZE CLUSTER
        super().__init__(asynchronous=asynchronous)

        if not self.asynchronous:
            self._loop_runner.start()
            self.sync(self.__get_defaults)

            if not self.telemetry_opt_out:
                self.__append_telemetry()

            self.sync(self.__create_cluster)

    async def __get_defaults(self):
        self.config = dask.config.get("cloudprovider.azure", {})

        if self.experiment_name is None:
            self.experiment_name = self.config.get("experiment_name")

        if self.initial_node_count is None:
            self.initial_node_count = self.config.get("initial_node_count")

        if self.jupyter is None:
            self.jupyter = self.config.get("jupyter")

        if self.jupyter_port is None:
            self.jupyter_port = self.config.get("jupyter_port")

        if self.dashboard_port is None:
            self.dashboard_port = self.config.get("dashboard_port")

        if self.scheduler_port is None:
            self.scheduler_port = self.config.get("scheduler_port")

        if self.scheduler_idle_timeout is None:
            self.scheduler_idle_timeout = self.config.get("scheduler_idle_timeout")

        if self.worker_death_timeout is None:
            self.worker_death_timeout = self.config.get("worker_death_timeout")

        if self.additional_ports is None:
            self.additional_ports = self.config.get("additional_ports")

        if self.admin_username is None:
            self.admin_username = self.config.get("admin_username")

        if self.admin_ssh_key is None:
            self.admin_ssh_key = self.config.get("admin_ssh_key")

        if self.datastores is None:
            self.datastores = self.config.get("datastores")

        if self.telemetry_opt_out is None:
            self.telemetry_opt_out = self.config.get("telemetry_opt_out")

        ### PARAMETERS TO START THE CLUSTER
        self.scheduler_params = {}
        self.worker_params = {}

        ### scheduler and worker parameters
        self.scheduler_params["--jupyter"] = self.jupyter
        self.scheduler_params["--scheduler_idle_timeout"] = self.scheduler_idle_timeout
        self.worker_params["--worker_death_timeout"] = self.worker_death_timeout

        if self.use_gpu:
            self.scheduler_params["--use_gpu"] = True
            self.scheduler_params["--n_gpus_per_node"] = self.n_gpus_per_node
            self.worker_params["--use_gpu"] = True
            self.worker_params["--n_gpus_per_node"] = self.n_gpus_per_node

        ### CLUSTER PARAMS
        self.max_nodes = self.compute_target.serialize()["properties"]["properties"][
            "scaleSettings"
        ]["maxNodeCount"]
        self.scheduler_ip_port = None
        self.workers_list = []
        self.URLs = {}

        ### SANITY CHECKS
        ###-----> initial node count
        if self.initial_node_count > self.max_nodes:
            self.initial_node_count = self.max_nodes

    def __append_telemetry(self):
        if not self.telemetry_set:
            self.telemetry_set = True
            try:
                from azureml._base_sdk_common.user_agent import append

                append("AzureMLCluster-DASK", "0.1")
            except ImportError:
                pass

    def __print_message(self, msg, length=80, filler="#", pre_post=""):
        logger.info(msg)
        if self.show_output:
            print(f"{pre_post} {msg} {pre_post}".center(length, filler))

    async def __check_if_scheduler_ip_reachable(self):
        """
        Private method to determine whether we are running in the cloud within
        the same VNET and the scheduler node is reachable.
        """
        try:
            ip, port = self.scheduler_ip_port.split(":")
            socket.create_connection((ip, port), 20)
            self.same_vnet = True
            self.__print_message("On the same VNET")
        except socket.timeout as e:
            self.__print_message("Not on the same VNET")
            self.same_vnet = False
        except ConnectionRefusedError as e:
            logger.info(e)
            self.__print_message(e)
            pass

    def __prepare_rpc_connection_to_headnode(self):
        if self.same_vnet:
            return self.run.get_metrics()["scheduler"]
        elif self.is_in_ci:
            uri = f"{self.hostname}:{self.scheduler_port}"
            return uri
        else:
            uri = f"localhost:{self.scheduler_port}"
            self.hostname = "localhost"
            logger.info(f"Local connection: {uri}")
            return uri

    def __get_ssh_keys(self):
        from cryptography.hazmat.primitives import serialization as crypto_serialization
        from cryptography.hazmat.primitives.asymmetric import rsa
        from cryptography.hazmat.backends import (
            default_backend as crypto_default_backend,
        )

        dir_path = os.path.join(os.getcwd(), "tmp")
        if not os.path.exists(dir_path):
            os.makedirs(dir_path)
        pub_key_file = os.path.join(dir_path, "key.pub")
        pri_key_file = os.path.join(dir_path, "key")

        key = rsa.generate_private_key(
            backend=crypto_default_backend(), public_exponent=65537, key_size=2048
        )
        private_key = key.private_bytes(
            crypto_serialization.Encoding.PEM,
            crypto_serialization.PrivateFormat.PKCS8,
            crypto_serialization.NoEncryption(),
        )
        public_key = key.public_key().public_bytes(
            crypto_serialization.Encoding.OpenSSH,
            crypto_serialization.PublicFormat.OpenSSH,
        )

        with open(pub_key_file, "wb") as f:
            f.write(public_key)

        with open(pri_key_file, "wb") as f:
            f.write(private_key)

        os.chmod(pri_key_file, 0o600)

        with open(pub_key_file, "r") as f:
            pubkey = f.read()

        self.pub_key_file = pub_key_file
        self.pri_key_file = pri_key_file

        return pubkey, pri_key_file

    def __create_compute_target(self):
        import random

        tmp_name = "dask-ct-{}".format(random.randint(100000, 999999))
        ct_name = self.kwargs.get("ct_name", tmp_name)
        vm_name = self.kwargs.get("vm_size", "STANDARD_DS3_V2")
        min_nodes = int(self.kwargs.get("min_nodes", "0"))
        max_nodes = int(self.kwargs.get("max_nodes", "100"))
        idle_time = int(self.kwargs.get("idle_time", "300"))
        vnet_rg = None
        vnet_name = None
        subnet_name = None

        if self.admin_username is None:
            self.admin_username = "******"
        ssh_key_pub, self.admin_ssh_key = self.__get_ssh_keys()

        if self.vnet and self.subnet:
            vnet_name = self.vnet
            subnet_name = self.subnet
            if self.vnet_resource_group:
                vnet_rg = self.vnet_resource_group
            else:
                vnet_rg = self.workspace.resource_group

        try:
            if ct_name not in self.workspace.compute_targets:
                config = AmlCompute.provisioning_configuration(
                    vm_size=vm_name,
                    min_nodes=min_nodes,
                    max_nodes=max_nodes,
                    vnet_resourcegroup_name=vnet_rg,
                    vnet_name=vnet_name,
                    subnet_name=subnet_name,
                    idle_seconds_before_scaledown=idle_time,
                    admin_username=self.admin_username,
                    admin_user_ssh_key=ssh_key_pub,
                    remote_login_port_public_access="Enabled",
                )

                self.__print_message("Creating new compute targe: {}".format(ct_name))
                ct = ComputeTarget.create(self.workspace, ct_name, config)
                ct.wait_for_completion(show_output=self.show_output)
            else:
                self.__print_message(
                    "Using existing compute target: {}".format(ct_name)
                )
                ct = self.workspace.compute_targets[ct_name]
        except Exception as e:
            logger.exception("Cannot create/get compute target. {}".format(e))
            raise e

        return ct

    def __delete_compute_target(self):
        try:
            self.compute_target.delete()
        except ComputeTargetException as e:
            logger.exception(
                "Compute target {} cannot be removed. You may need to delete it manually. {}".format(
                    self.compute_target.name, e
                )
            )

    async def __create_cluster(self):
        self.__print_message("Setting up cluster")
        exp = Experiment(self.workspace, self.experiment_name)
        estimator = Estimator(
            os.path.join(self.abs_path, "setup"),
            compute_target=self.compute_target,
            entry_script="start_scheduler.py",
            environment_definition=self.environment_definition,
            script_params=self.scheduler_params,
            node_count=1,  ### start only scheduler
            distributed_training=MpiConfiguration(),
            use_docker=True,
            inputs=self.datastores,
        )

        run = exp.submit(estimator, tags=self.tags)

        self.__print_message("Waiting for scheduler node's IP")
        status = run.get_status()
        while (
            status != "Canceled"
            and status != "Failed"
            and "scheduler" not in run.get_metrics()
        ):
            print(".", end="")
            logger.info("Scheduler not ready")
            time.sleep(5)
            status = run.get_status()

        if status == "Canceled" or status == "Failed":
            run_error = run.get_details().get("error")
            error_message = "Failed to start the AzureML cluster."

            if run_error:
                error_message = "{} {}".format(error_message, run_error)
            logger.exception(error_message)

            if not self.compute_target_set:
                self.__delete_compute_target()

            raise Exception(error_message)

        print("\n")

        ### SET FLAGS
        self.scheduler_ip_port = run.get_metrics()["scheduler"]
        self.worker_params["--scheduler_ip_port"] = self.scheduler_ip_port
        self.__print_message(f'Scheduler: {run.get_metrics()["scheduler"]}')
        self.run = run

        ### CHECK IF ON THE SAME VNET
        max_retry = 5
        while self.same_vnet is None and max_retry > 0:
            time.sleep(5)
            await self.sync(self.__check_if_scheduler_ip_reachable)
            max_retry -= 1

        if self.same_vnet is None:
            self.run.cancel()
            if not self.compute_target_set:
                self.__delete_compute_target()
            logger.exception(
                "Connection error after retrying. Failed to start the AzureML cluster."
            )
            return

        ### REQUIRED BY dask.distributed.deploy.cluster.Cluster
        self.hostname = socket.gethostname()
        self.is_in_ci = (
            f"/mnt/batch/tasks/shared/LS_root/mounts/clusters/{self.hostname}"
            in os.getcwd()
        )
        _scheduler = self.__prepare_rpc_connection_to_headnode()
        self.scheduler_comm = rpc(_scheduler)
        await self.sync(self.__setup_port_forwarding)

        try:
            await super()._start()
        except Exception as e:
            logger.exception(e)
            # CLEAN UP COMPUTE TARGET
            self.run.cancel()
            if not self.compute_target_set:
                self.__delete_compute_target()
            return

        await self.sync(self.__update_links)

        self.__print_message("Connections established")
        self.__print_message(f"Scaling to {self.initial_node_count} workers")

        if self.initial_node_count > 1:
            self.scale(
                self.initial_node_count
            )  # LOGIC TO KEEP PROPER TRACK OF WORKERS IN `scale`
        self.__print_message("Scaling is done")

    async def __update_links(self):
        token = self.run.get_metrics()["token"]

        if self.same_vnet or self.is_in_ci:
            location = self.workspace.get_details()["location"]

            self.scheduler_info[
                "dashboard_url"
            ] = f"https://{self.hostname}-{self.dashboard_port}.{location}.instances.azureml.net/status"

            self.scheduler_info[
                "jupyter_url"
            ] = f"https://{self.hostname}-{self.jupyter_port}.{location}.instances.azureml.net/lab?token={token}"
        else:
            hostname = "localhost"
            self.scheduler_info[
                "dashboard_url"
            ] = f"http://{hostname}:{self.dashboard_port}"
            self.scheduler_info[
                "jupyter_url"
            ] = f"http://{hostname}:{self.jupyter_port}/?token={token}"

        logger.info(f'Dashboard URL: {self.scheduler_info["dashboard_url"]}')
        logger.info(f'Jupyter URL:   {self.scheduler_info["jupyter_url"]}')

    def __port_forward_logger(self, portforward_proc):
        portforward_log = open("portforward_out_log.txt", "w")

        while True:
            portforward_out = portforward_proc.stdout.readline()
            if portforward_out != "":
                portforward_log.write(portforward_out)
                portforward_log.flush()

            if self.end_logging:
                break
        return

    async def __setup_port_forwarding(self):
        dashboard_address = self.run.get_metrics()["dashboard"]
        jupyter_address = self.run.get_metrics()["jupyter"]
        scheduler_ip = self.run.get_metrics()["scheduler"].split(":")[0]

        self.__print_message("Running in compute instance? {}".format(self.is_in_ci))
        os.system(
            "killall socat"
        )  # kill all socat processes - cleans up previous port forward setups
        if self.same_vnet:
            os.system(
                f"setsid socat tcp-listen:{self.dashboard_port},reuseaddr,fork tcp:{dashboard_address} &"
            )
            os.system(
                f"setsid socat tcp-listen:{self.jupyter_port},reuseaddr,fork tcp:{jupyter_address} &"
            )

            ### map additional ports
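            ### each entry is (headnode_port, local_port): listen on local_port, forward to headnode_port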
            for port in self.additional_ports:
                os.system(
                    f"setsid socat tcp-listen:{self.port[1]},reuseaddr,fork tcp:{scheduler_ip}:{port[0]} &"
                )
        else:
            scheduler_public_ip = self.compute_target.list_nodes()[0]["publicIpAddress"]
            scheduler_public_port = self.compute_target.list_nodes()[0]["port"]
            self.__print_message("scheduler_public_ip: {}".format(scheduler_public_ip))
            self.__print_message(
                "scheduler_public_port: {}".format(scheduler_public_port)
            )

            host_ip = "0.0.0.0"
            if self.is_in_ci:
                host_ip = socket.gethostbyname(self.hostname)

            cmd = (
                "ssh -vvv -o StrictHostKeyChecking=no -N"
                f" -i {os.path.expanduser(self.admin_ssh_key)}"
                f" -L {host_ip}:{self.jupyter_port}:{scheduler_ip}:8888"
                f" -L {host_ip}:{self.dashboard_port}:{scheduler_ip}:8787"
                f" -L {host_ip}:{self.scheduler_port}:{scheduler_ip}:8786"
            )

            for port in self.additional_ports:
                cmd += f" -L {host_ip}:{port[1]}:{scheduler_ip}:{port[0]}"

            cmd += f" {self.admin_username}@{scheduler_public_ip} -p {scheduler_public_port}"

            self.portforward_proc = subprocess.Popen(
                cmd.split(),
                universal_newlines=True,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
            )

            ### Starting thread to keep the SSH tunnel open on Windows
            portforward_logg = threading.Thread(
                target=self.__port_forward_logger, args=[self.portforward_proc]
            )
            portforward_logg.start()

    @property
    def dashboard_link(self):
        """ Link to Dask dashboard.
        """
        try:
            link = self.scheduler_info["dashboard_url"]
        except KeyError:
            return ""
        else:
            return link

    @property
    def jupyter_link(self):
        """ Link to JupyterLab on running on the headnode of the cluster.
        Set ``jupyter=True`` when creating the ``AzureMLCluster``.
        """
        try:
            link = self.scheduler_info["jupyter_url"]
        except KeyError:
            return ""
        else:
            return link

    def _format_nodes(self, nodes, requested, use_gpu, n_gpus_per_node=None):
        if nodes == requested:
            return f"{nodes}"
        return f"{nodes} / {requested}"

    def _widget_status(self):
        ### reporting proper number of nodes vs workers in a multi-GPU worker scenario
        nodes = len(self.scheduler_info["workers"])

        if self.use_gpu:
            nodes = int(nodes / self.n_gpus_per_node)
        if hasattr(self, "worker_spec"):
            requested = sum(
                1 if "group" not in each else len(each["group"])
                for each in self.worker_spec.values()
            )

        elif hasattr(self, "nodes"):
            requested = len(self.nodes)
        else:
            requested = nodes

        nodes = self._format_nodes(nodes, requested, self.use_gpu, self.n_gpus_per_node)

        cores = sum(v["nthreads"] for v in self.scheduler_info["workers"].values())
        cores_or_gpus = "Workers (GPUs)" if self.use_gpu else "Workers (vCPUs)"

        memory = (
            sum(
                v["gpu"]["memory-total"][0]
                for v in self.scheduler_info["workers"].values()
            )
            if self.use_gpu
            else sum(v["memory_limit"] for v in self.scheduler_info["workers"].values())
        )
        memory = format_bytes(memory)

        text = """
<div>
  <style scoped>
    .dataframe tbody tr th:only-of-type {
        vertical-align: middle;
    }
    .dataframe tbody tr th {
        vertical-align: top;
    }
    .dataframe thead th {
        text-align: right;
    }
  </style>
  <table style="text-align: right;">
    <tr> <th>Nodes</th> <td>%s</td></tr>
    <tr> <th>%s</th> <td>%s</td></tr>
    <tr> <th>Memory</th> <td>%s</td></tr>
  </table>
</div>
""" % (
            nodes,
            cores_or_gpus,
            cores,
            memory,
        )
        return text

    def _widget(self):
        """ Create IPython widget for display within a notebook """
        try:
            return self._cached_widget
        except AttributeError:
            pass

        try:
            from ipywidgets import Layout, VBox, HBox, IntText, Button, HTML, Accordion
        except ImportError:
            self._cached_widget = None
            return None

        layout = Layout(width="150px")

        if self.dashboard_link:
            dashboard_link = (
                '<p><b>Dashboard: </b><a href="%s" target="_blank">%s</a></p>\n'
                % (self.dashboard_link, self.dashboard_link)
            )
        else:
            dashboard_link = ""

        if self.jupyter_link:
            jupyter_link = (
                '<p><b>Jupyter: </b><a href="%s" target="_blank">%s</a></p>\n'
                % (self.jupyter_link, self.jupyter_link)
            )
        else:
            jupyter_link = ""

        title = "<h2>%s</h2>" % self._cluster_class_name
        title = HTML(title)
        dashboard = HTML(dashboard_link)
        jupyter = HTML(jupyter_link)

        status = HTML(self._widget_status(), layout=Layout(min_width="150px"))

        if self._supports_scaling:
            request = IntText(
                self.initial_node_count, description="Nodes", layout=layout
            )
            scale = Button(description="Scale", layout=layout)

            minimum = IntText(0, description="Minimum", layout=layout)
            maximum = IntText(0, description="Maximum", layout=layout)
            adapt = Button(description="Adapt", layout=layout)

            accordion = Accordion(
                [HBox([request, scale]), HBox([minimum, maximum, adapt])],
                layout=Layout(min_width="500px"),
            )
            accordion.selected_index = None
            accordion.set_title(0, "Manual Scaling")
            accordion.set_title(1, "Adaptive Scaling")

            def adapt_cb(b):
                self.adapt(minimum=minimum.value, maximum=maximum.value)
                update()

            adapt.on_click(adapt_cb)

            def scale_cb(b):
                with log_errors():
                    n = request.value
                    with suppress(AttributeError):
                        self._adaptive.stop()
                    self.scale(n)
                    update()

            scale.on_click(scale_cb)
        else:
            accordion = HTML("")

        box = VBox([title, HBox([status, accordion]), jupyter, dashboard])

        self._cached_widget = box

        def update():
            self.close_when_disconnect()
            status.value = self._widget_status()

        pc = PeriodicCallback(update, 500)
        self.periodic_callbacks["cluster-repr"] = pc
        pc.start()

        return box

    def close_when_disconnect(self):
        status = self.run.get_status()
        if status == "Canceled" or status == "Completed" or status == "Failed":
            self.close()

    def scale(self, workers=1):
        """ Scale the cluster. Scales to a maximum of the workers available in the cluster.
        """
        if workers <= 0:
            self.close()
            return

        count = len(self.workers_list) + 1  # one more worker in head node

        if count < workers:
            self.scale_up(workers - count)
        elif count > workers:
            self.scale_down(count - workers)
        else:
            self.__print_message(f"Number of workers: {workers}")

    # scale up
    def scale_up(self, workers=1):
        """ Scale up the number of workers.
        """
        run_config = RunConfiguration()
        run_config.target = self.compute_target
        run_config.environment = self.environment_definition

        scheduler_ip = self.run.get_metrics()["scheduler"]
        args = [
            f"--scheduler_ip_port={scheduler_ip}",
            f"--use_gpu={self.use_gpu}",
            f"--n_gpus_per_node={self.n_gpus_per_node}",
            f"--worker_death_timeout={self.worker_death_timeout}",
        ]

        child_run_config = ScriptRunConfig(
            source_directory=os.path.join(self.abs_path, "setup"),
            script="start_worker.py",
            arguments=args,
            run_config=run_config,
        )

        for i in range(workers):
            child_run = self.run.submit_child(child_run_config, tags=self.tags)
            self.workers_list.append(child_run)

    # scale down
    def scale_down(self, workers=1):
        """ Scale down the number of workers. Scales to minimum of 1.
        """
        for i in range(workers):
            if self.workers_list:
                child_run = self.workers_list.pop(0)  # deactivate oldest workers
                child_run.complete()  # complete() will mark the run "Complete", but won't kill the process
                child_run.cancel()
            else:
                self.__print_message("All scaled workers are removed.")

    # close cluster
    async def _close(self):
        if self.status == "closed":
            return

        while self.workers_list:
            child_run = self.workers_list.pop()
            child_run.complete()
            child_run.cancel()

        if self.run:
            self.run.complete()
            self.run.cancel()

        self.status = "closed"
        self.__print_message("Scheduler and workers are disconnected.")

        if self.portforward_proc is not None:
            ### STOP LOGGING SSH
            self.portforward_proc.terminate()
            self.end_logging = True

        ### REMOVE TEMP FILE
        if os.path.isfile(self.pub_key_file):
            os.remove(self.pub_key_file)
        if os.path.isfile(self.pri_key_file):
            os.remove(self.pri_key_file)

        if not self.compute_target_set:
            ### REMOVE COMPUTE TARGET
            self.__delete_compute_target()

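        ### Assumed: give Azure ML time to process the run cancellations before the final close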
        time.sleep(30)
        await super()._close()

    def close(self):
        """ Close the cluster. All Azure ML Runs corresponding to the scheduler
        and worker processes will be completed. The Azure ML Compute Target will
        return to its minimum number of nodes after its idle time before scaledown.
        """
        return self.sync(self._close)
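
For orientation, a hedged usage sketch of the cluster lifecycle implemented above (the class matches the AzureMLCluster constructor shown in Ejemplo n.º 12); the azureml-sdk lookups are standard, but the resource names are placeholders rather than values taken from this example:

# Hedged usage sketch -- the resource names below are hypothetical.
from azureml.core import Workspace, ComputeTarget, Environment

ws = Workspace.from_config()                          # assumes a local config.json
ct = ComputeTarget(workspace=ws, name="dask-ct")      # hypothetical compute target
env = Environment.get(workspace=ws, name="dask-env")  # hypothetical environment

cluster = AzureMLCluster(ws, ct, env, initial_node_count=2)
cluster.scale(4)   # counts the headnode worker, see scale() above
cluster.close()    # completes/cancels all runs and removes the temp SSH keys
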
Ejemplo n.º 9
0
class Gateway(object):
    """A client for a Dask Gateway Server.

    Parameters
    ----------
    address : str, optional
        The address to the gateway server.
    proxy_address : str, int, optional
        The address of the scheduler proxy server. If an int, it's used as the
        port, with the host/ip taken from ``address``. Provide a full address
        if a different host/ip should be used.
    auth : GatewayAuth, optional
        The authentication method to use.
    asynchronous : bool, optional
        If true, starts the client in asynchronous mode, where it can be used
        in other async code.
    loop : IOLoop, optional
        The IOLoop instance to use. Defaults to the current loop in
        asynchronous mode, otherwise a background loop is started.
    """
    def __init__(self,
                 address=None,
                 proxy_address=None,
                 auth=None,
                 asynchronous=False,
                 loop=None):
        if address is None:
            address = dask.config.get("gateway.address")
        if address is None:
            raise ValueError(
                "No dask-gateway address provided or found in configuration")
        address = address.rstrip("/")

        if proxy_address is None:
            proxy_address = dask.config.get("gateway.proxy-address")
        if proxy_address is None:
            raise ValueError(
                "No dask-gateway proxy address provided or found in configuration"
            )
        if isinstance(proxy_address, int):
            parsed = urlparse(address)
            proxy_netloc = "%s:%d" % (parsed.hostname, proxy_address)
        elif isinstance(proxy_address, str):
            parsed = urlparse(proxy_address)
            proxy_netloc = parsed.netloc if parsed.netloc else proxy_address
        proxy_address = "gateway://%s" % proxy_netloc

        self.address = address
        self.proxy_address = proxy_address

        self._auth = get_auth(auth)
        self._cookie_jar = CookieJar()

        self._asynchronous = asynchronous
        self._loop_runner = LoopRunner(loop=loop, asynchronous=asynchronous)
        self.loop = self._loop_runner.loop

        self._loop_runner.start()
        if self._asynchronous:
            self._started = self._start()
        else:
            self.sync(self._start)

    async def _start(self):
        self._http_client = AsyncHTTPClient()

    def close(self):
        """Close this gateway client"""
        if not self.asynchronous:
            self._loop_runner.stop()

    def __del__(self):
        # __del__ is still called, even if __init__ failed. Only close if init
        # actually succeeded.
        if hasattr(self, "_started"):
            self.close()

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

    def __await__(self):
        return self._started.__await__()

    async def __aenter__(self):
        await self._started
        return self

    async def __aexit__(self, typ, value, traceback):
        pass

    def __repr__(self):
        return "Gateway<%s>" % self.address

    @property
    def asynchronous(self):
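        # Async when constructed asynchronous, called from async-flagged code,
        # or when the IOLoop runs in the calling thread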
        return (self._asynchronous
                or getattr(thread_state, "asynchronous", False)
                or (hasattr(self.loop, "_thread_identity")
                    and self.loop._thread_identity == get_ident()))

    def sync(self, func, *args, **kwargs):
        if kwargs.pop("asynchronous", None) or self.asynchronous:
            callback_timeout = kwargs.pop("callback_timeout", None)
            future = func(*args, **kwargs)
            if callback_timeout is not None:
                future = gen.with_timeout(timedelta(seconds=callback_timeout),
                                          future)
            return future
        else:
            return sync(self.loop, func, *args, **kwargs)

    async def _fetch(self, req):
        try:
            self._cookie_jar.pre_request(req)
            resp = await self._http_client.fetch(req, raise_error=False)
            if resp.code == 401:
                context = self._auth.pre_request(req, resp)
                resp = await self._http_client.fetch(req, raise_error=False)
                self._auth.post_response(req, resp, context)
            self._cookie_jar.post_response(resp)
            if resp.error:
                if resp.code == 599:
                    raise TimeoutError("Request timed out")
                else:
                    try:
                        msg = json.loads(resp.body)["error"]
                    except Exception:
                        msg = resp.body.decode()

                    if resp.code in {404, 422}:
                        raise ValueError(msg)
                    elif resp.code == 409:
                        raise GatewayClusterError(msg)
                    elif resp.code == 500:
                        raise GatewayServerError(msg)
                    else:
                        resp.rethrow()
        except HTTPError as exc:
            # Tornado 6 still raises these above with raise_error=False
            if exc.code == 599:
                raise TimeoutError("Request timed out")
            # Should never get here!
            raise
        return resp

    async def _clusters(self, status=None):
        if status is not None:
            if isinstance(status, (str, ClusterStatus)):
                status = [ClusterStatus._create(status)]
            else:
                status = [ClusterStatus._create(s) for s in status]
            query = "?status=" + ",".join(s.name for s in status)
        else:
            query = ""

        url = "%s/gateway/api/clusters/%s" % (self.address, query)
        req = HTTPRequest(url=url)
        resp = await self._fetch(req)
        return [
            ClusterReport._from_json(self.address, self.proxy_address, r)
            for r in json.loads(resp.body).values()
        ]

    def list_clusters(self, status=None, **kwargs):
        """List clusters for this user.

        Parameters
        ----------
        status : ClusterStatus, str, or list, optional
            The cluster status (or statuses) to select. Valid options are
            'starting', 'started', 'running', 'stopping', 'stopped', 'failed'.
            By default selects active clusters ('starting', 'started',
            'running').

        Returns
        -------
        clusters : list of ClusterReport
        """
        return self.sync(self._clusters, status=status, **kwargs)

    async def _cluster_options(self):
        url = "%s/gateway/api/clusters/options" % self.address
        req = HTTPRequest(url=url, method="GET")
        resp = await self._fetch(req)
        data = json.loads(resp.body)
        return Options._from_spec(data["cluster_options"])

    def cluster_options(self, **kwargs):
        """Get the available cluster configuration options.

        Returns
        -------
        cluster_options : Options
            A dict of cluster options.
        """
        return self.sync(self._cluster_options, **kwargs)

    async def _submit(self, cluster_options=None, **kwargs):
        url = "%s/gateway/api/clusters/" % self.address
        if cluster_options is not None:
            if not isinstance(cluster_options, Options):
                raise TypeError(
                    "cluster_options must be an `Options`, got %r" %
                    type(cluster_options).__name__)
            options = dict(cluster_options)
            options.update(kwargs)
        else:
            options = kwargs
        req = HTTPRequest(
            url=url,
            method="POST",
            body=json.dumps({"cluster_options": options}),
            headers=HTTPHeaders({"Content-type": "application/json"}),
        )
        resp = await self._fetch(req)
        data = json.loads(resp.body)
        return data["name"]

    def submit(self, cluster_options=None, **kwargs):
        """Submit a new cluster to be started.

        This returns quickly with a ``cluster_name``, which can later be used
        to connect to the cluster.

        Parameters
        ----------
        cluster_options : Options, optional
            An ``Options`` object describing the desired cluster configuration.
        **kwargs :
            Additional cluster configuration options. If ``cluster_options`` is
            provided, these are applied afterwards as overrides. Available
            options are specific to each deployment of dask-gateway, see
            ``cluster_options`` for more information.

        Returns
        -------
        cluster_name : str
            The cluster name.
        """
        return self.sync(self._submit, cluster_options=cluster_options, **kwargs)

    async def _cluster_report(self, cluster_name, wait=False):
        params = "?wait" if wait else ""
        url = "%s/gateway/api/clusters/%s%s" % (self.address, cluster_name,
                                                params)
        req = HTTPRequest(url=url)
        resp = await self._fetch(req)
        return ClusterReport._from_json(self.address, self.proxy_address,
                                        json.loads(resp.body))

    async def _connect(self, cluster_name):
        while True:
            try:
                report = await self._cluster_report(cluster_name, wait=True)
            except TimeoutError:
                # Timeout, ignore
                pass
            else:
                if report.status is ClusterStatus.RUNNING:
                    return GatewayCluster(
                        gateway=self,
                        name=report.name,
                        scheduler_address=report.scheduler_address,
                        dashboard_link=report.dashboard_link,
                        security=report.security,
                    )
                elif report.status is ClusterStatus.FAILED:
                    raise GatewayClusterError(
                        "Cluster %r failed to start, see logs for "
                        "more information" % cluster_name)
                elif report.status is ClusterStatus.STOPPED:
                    raise GatewayClusterError("Cluster %r is already stopped" %
                                              cluster_name)
            # Not started yet, try again later
            await gen.sleep(0.5)

    def connect(self, cluster_name, **kwargs):
        """Connect to a submitted cluster.

        Returns
        -------
        cluster : GatewayCluster
        """
        return self.sync(self._connect, cluster_name, **kwargs)

    async def _new_cluster(self, **kwargs):
        cluster_name = await self._submit(**kwargs)
        try:
            return await self._connect(cluster_name)
        except GatewayClusterError:
            raise
        except BaseException:
            # Ensure cluster is stopped on error
            await self._stop_cluster(cluster_name)
            raise

    def new_cluster(self, cluster_options=None, **kwargs):
        """Submit a new cluster to the gateway, and wait for it to be started.

        Same as calling ``submit`` and ``connect`` in one go.

        Parameters
        ----------
        cluster_options : Options, optional
            An ``Options`` object describing the desired cluster configuration.
        **kwargs :
            Additional cluster configuration options. If ``cluster_options`` is
            provided, these are applied afterwards as overrides. Available
            options are specific to each deployment of dask-gateway, see
            ``cluster_options`` for more information.

        Returns
        -------
        cluster : GatewayCluster
        """
        return self.sync(self._new_cluster,
                         cluster_options=cluster_options,
                         **kwargs)

    async def _stop_cluster(self, cluster_name):
        url = "%s/gateway/api/clusters/%s" % (self.address, cluster_name)
        req = HTTPRequest(url=url, method="DELETE")
        await self._fetch(req)

    def stop_cluster(self, cluster_name, **kwargs):
        """Stop a cluster.

        Parameters
        ----------
        cluster_name : str
            The cluster name.
        """
        return self.sync(self._stop_cluster, cluster_name, **kwargs)

    async def _scale_cluster(self, cluster_name, n):
        url = "%s/gateway/api/clusters/%s/workers" % (self.address,
                                                      cluster_name)
        req = HTTPRequest(
            url=url,
            method="PUT",
            body=json.dumps({"worker_count": n}),
            headers=HTTPHeaders({"Content-type": "application/json"}),
        )
        await self._fetch(req)

    def scale_cluster(self, cluster_name, n, **kwargs):
        """Scale a cluster to n workers.

        Parameters
        ----------
        cluster_name : str
            The cluster name.
        n : int
            The number of workers to scale to.
        """
        return self.sync(self._scale_cluster, cluster_name, n, **kwargs)
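
A hedged sketch of blocking use of this client; the addresses are placeholders, not values from this example:

gateway = Gateway(
    address="http://mygateway.example.com",  # hypothetical gateway server
    proxy_address=8786,                      # int: reuse the address host with this port
)
name = gateway.submit()          # returns quickly with the new cluster's name
cluster = gateway.connect(name)  # polls until the cluster reports RUNNING
gateway.scale_cluster(name, 4)
gateway.stop_cluster(name)
gateway.close()
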
Ejemplo n.º 10
0
class Gateway(object):
    """A client for a Dask Gateway Server.

    Parameters
    ----------
    address : str, optional
        The address to the gateway server.
    auth : GatewayAuth, optional
        The authentication method to use.
    asynchronous : bool, optional
        If true, starts the client in asynchronous mode, where it can be used
        in other async code.
    loop : IOLoop, optional
        The IOLoop instance to use. Defaults to the current loop in
        asynchronous mode, otherwise a background loop is started.
    """
    def __init__(self, address=None, auth=None, asynchronous=False, loop=None):
        if address is None:
            address = dask.config.get("gateway.address")
        if address is None:
            raise ValueError(
                "No dask-gateway address provided or found in configuration")
        self.address = address.rstrip("/")
        self._auth = get_auth(auth)
        self._cookie_jar = CookieJar()

        self._asynchronous = asynchronous
        self._loop_runner = LoopRunner(loop=loop, asynchronous=asynchronous)
        self.loop = self._loop_runner.loop

        self._loop_runner.start()
        if self._asynchronous:
            self._started = self._start()
        else:
            self.sync(self._start)

    async def _start(self):
        self._http_client = AsyncHTTPClient()

    def close(self):
        """Close this gateway client"""
        if not self.asynchronous:
            self._loop_runner.stop()

    def __del__(self):
        # __del__ is still called, even if __init__ failed. Only close if init
        # actually succeeded.
        if hasattr(self, "_started"):
            self.close()

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

    def __await__(self):
        return self._started.__await__()

    async def __aenter__(self):
        await self._started
        return self

    async def __aexit__(self, typ, value, traceback):
        pass

    def __repr__(self):
        return "Gateway<%s>" % self.address

    @property
    def asynchronous(self):
        return (self._asynchronous
                or getattr(thread_state, "asynchronous", False)
                or (hasattr(self.loop, "_thread_identity")
                    and self.loop._thread_identity == get_ident()))

    def sync(self, func, *args, **kwargs):
        if kwargs.pop("asynchronous", None) or self.asynchronous:
            callback_timeout = kwargs.pop("callback_timeout", None)
            future = func(*args, **kwargs)
            if callback_timeout is not None:
                future = gen.with_timeout(timedelta(seconds=callback_timeout),
                                          future)
            return future
        else:
            return sync(self.loop, func, *args, **kwargs)

    async def _fetch(self, req, raise_error=True):
        self._cookie_jar.pre_request(req)
        resp = await self._http_client.fetch(req, raise_error=False)
        if resp.code == 401:
            context = self._auth.pre_request(req, resp)
            resp = await self._http_client.fetch(req, raise_error=False)
            self._auth.post_response(req, resp, context)
        self._cookie_jar.post_response(resp)
        if raise_error:
            resp.rethrow()
        return resp

    async def _clusters(self, status=None):
        if status is not None:
            if isinstance(status, (str, ClusterStatus)):
                status = [ClusterStatus._create(status)]
            else:
                status = [ClusterStatus._create(s) for s in status]
            query = "?status=" + ",".join(s.name for s in status)
        else:
            query = ""

        url = "%s/gateway/api/clusters/%s" % (self.address, query)
        req = HTTPRequest(url=url)
        resp = await self._fetch(req)
        return [
            ClusterReport._from_json(self.address, r)
            for r in json.loads(resp.body).values()
        ]

    def list_clusters(self, status=None, **kwargs):
        """List clusters for this user.

        Parameters
        ----------
        status : ClusterStatus, str, or list, optional
            The cluster status (or statuses) to select. Valid options are
            'starting', 'started', 'running', 'stopping', 'stopped', 'failed'.
            By default selects active clusters ('starting', 'started',
            'running').

        Returns
        -------
        clusters : list of ClusterReport
        """
        return self.sync(self._clusters, status=status, **kwargs)

    async def _submit(self):
        url = "%s/gateway/api/clusters/" % self.address
        req = HTTPRequest(
            url=url,
            method="POST",
            body=json.dumps({}),
            headers=HTTPHeaders({"Content-type": "application/json"}),
        )
        resp = await self._fetch(req)
        data = json.loads(resp.body)
        return data["name"]

    def submit(self, **kwargs):
        """Submit a new cluster to be started.

        This returns quickly with a ``cluster_name``, which can later be used
        to connect to the cluster.

        Returns
        -------
        cluster_name : str
            The cluster name.
        """
        return self.sync(self._submit, **kwargs)

    async def _cluster_report(self, cluster_name, wait=False):
        params = "?wait" if wait else ""
        url = "%s/gateway/api/clusters/%s%s" % (self.address, cluster_name,
                                                params)
        req = HTTPRequest(url=url)
        resp = await self._fetch(req)
        return ClusterReport._from_json(self.address, json.loads(resp.body))

    async def _connect(self, cluster_name):
        while True:
            try:
                report = await self._cluster_report(cluster_name, wait=True)
            except HTTPError as exc:
                if exc.code == 404:
                    raise Exception("Unknown cluster %r" % cluster_name)
                elif exc.code == 599:
                    # Timeout, ignore
                    pass
                else:
                    raise
            else:
                if report.status is ClusterStatus.RUNNING:
                    return GatewayCluster(self, report)
                elif report.status is ClusterStatus.FAILED:
                    raise Exception("Cluster %r failed to start, see logs for "
                                    "more information" % cluster_name)
                elif report.status is ClusterStatus.STOPPED:
                    raise Exception("Cluster %r is already stopped" %
                                    cluster_name)
            # Not started yet, try again later
            await gen.sleep(0.5)

    def connect(self, cluster_name, **kwargs):
        """Connect to a submitted cluster.

        Returns
        -------
        cluster : GatewayCluster
        """
        return self.sync(self._connect, cluster_name, **kwargs)

    async def _new_cluster(self, **kwargs):
        cluster_name = await self._submit(**kwargs)
        try:
            return await self._connect(cluster_name)
        except BaseException:
            # Ensure cluster is stopped on error
            await self._stop_cluster(cluster_name)
            raise

    def new_cluster(self, **kwargs):
        """Submit a new cluster to the gateway, and wait for it to be started.

        Same as calling ``submit`` and ``connect`` in one go.

        Returns
        -------
        cluster : GatewayCluster
        """
        return self.sync(self._new_cluster, **kwargs)

    async def _stop_cluster(self, cluster_name):
        url = "%s/gateway/api/clusters/%s" % (self.address, cluster_name)
        req = HTTPRequest(url=url, method="DELETE")
        await self._fetch(req)

    def stop_cluster(self, cluster_name, **kwargs):
        """Stop a cluster.

        Parameters
        ----------
        cluster_name : str
            The cluster name.
        """
        return self.sync(self._stop_cluster, cluster_name, **kwargs)

    async def _scale_cluster(self, cluster_name, n):
        url = "%s/gateway/api/clusters/%s/workers" % (self.address,
                                                      cluster_name)
        req = HTTPRequest(
            url=url,
            method="PUT",
            body=json.dumps({"worker_count": n}),
            headers=HTTPHeaders({"Content-type": "application/json"}),
        )
        try:
            await self._fetch(req)
        except HTTPError as exc:
            if exc.code == 409:
                raise Exception("Cluster %r is not running" % cluster_name)
            raise

    def scale_cluster(self, cluster_name, n, **kwargs):
        """Scale a cluster to n workers.

        Parameters
        ----------
        cluster_name : str
            The cluster name.
        n : int
            The number of workers to scale to.
        """
        return self.sync(self._scale_cluster, cluster_name, n, **kwargs)
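
The same client can also be driven from async code via the __await__/__aenter__ hooks above; a hedged sketch, with a placeholder address:

from tornado.ioloop import IOLoop

async def main():
    async with Gateway(address="http://mygateway.example.com",
                       asynchronous=True) as gateway:
        clusters = await gateway.list_clusters()  # sync() returns the coroutine in async mode
        print(clusters)

IOLoop.current().run_sync(main)
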
Ejemplo n.º 11
0
class Gateway(object):
    """A client for a Dask Gateway Server.

    Parameters
    ----------
    address : str, optional
        The address to the gateway server.
    proxy_address : str, int, optional
        The address of the scheduler proxy server. Defaults to `address` if not
        provided. If an int, it's used as the port, with the host/ip taken from
        ``address``. Provide a full address if a different host/ip should be
        used.
    auth : GatewayAuth, optional
        The authentication method to use.
    asynchronous : bool, optional
        If true, starts the client in asynchronous mode, where it can be used
        in other async code.
    loop : IOLoop, optional
        The IOLoop instance to use. Defaults to the current loop in
        asynchronous mode, otherwise a background loop is started.
    """
    def __init__(self,
                 address=None,
                 proxy_address=None,
                 auth=None,
                 asynchronous=False,
                 loop=None):
        if address is None:
            address = format_template(dask.config.get("gateway.address"))
        if address is None:
            raise ValueError(
                "No dask-gateway address provided or found in configuration")
        address = address.rstrip("/")

        public_address = format_template(
            dask.config.get("gateway.public-address"))
        if public_address is None:
            public_address = address
        else:
            public_address = public_address.rstrip("/")

        if proxy_address is None:
            proxy_address = format_template(
                dask.config.get("gateway.proxy-address"))
        if proxy_address is None:
            parsed = urlparse(address)
            if parsed.netloc:
                if parsed.port is None:
                    proxy_port = {
                        "http": 80,
                        "https": 443
                    }.get(parsed.scheme, 8786)
                else:
                    proxy_port = parsed.port
                proxy_netloc = "%s:%d" % (parsed.hostname, proxy_port)
            else:
                proxy_netloc = proxy_address
        elif isinstance(proxy_address, int):
            parsed = urlparse(address)
            proxy_netloc = "%s:%d" % (parsed.hostname, proxy_address)
        elif isinstance(proxy_address, str):
            parsed = urlparse(proxy_address)
            proxy_netloc = parsed.netloc if parsed.netloc else proxy_address
        proxy_address = "gateway://%s" % proxy_netloc

        scheme = urlparse(address).scheme
        self._request_kwargs = _get_default_request_kwargs(scheme)

        self.address = address
        self._public_address = public_address
        self.proxy_address = proxy_address

        self.auth = get_auth(auth)
        self._session = None

        self._asynchronous = asynchronous
        self._loop_runner = LoopRunner(loop=loop, asynchronous=asynchronous)
        self._loop_runner.start()

    @property
    def loop(self):
        return self._loop_runner.loop

    @property
    def asynchronous(self):
        return self._asynchronous

    def sync(self, func, *args, **kwargs):
        if self.asynchronous:
            return func(*args, **kwargs)
        else:
            future = asyncio.run_coroutine_threadsafe(func(*args, **kwargs),
                                                      self.loop.asyncio_loop)
            try:
                return future.result()
            except BaseException:
                future.cancel()
                raise

    def close(self):
        """Close the gateway client"""
        if self.asynchronous:
            return self._cleanup()
        elif self.loop.asyncio_loop.is_running():
            self.sync(self._cleanup)
        self._loop_runner.stop()

    async def _cleanup(self):
        if self._session is not None:
            await self._session.close()
            self._session = None

    async def __aenter__(self):
        return self

    async def __aexit__(self, typ, value, traceback):
        await self._cleanup()

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

    def __del__(self):
        if (not self.asynchronous and hasattr(self, "_loop_runner")
                and not sys.is_finalizing()):
            self.close()

    def __repr__(self):
        return "Gateway<%s>" % self.address

    async def _request(self, method, url, json=None):
        if self._session is None:
            self._session = aiohttp.ClientSession()
        session = self._session

        resp = await session.request(method,
                                     url,
                                     json=json,
                                     **self._request_kwargs)

        if resp.status == 401:
            headers, context = self.auth.pre_request(resp)
            resp = await session.request(method,
                                         url,
                                         json=json,
                                         headers=headers,
                                         **self._request_kwargs)
            self.auth.post_response(resp, context)

        if resp.status >= 400:
            try:
                msg = await resp.json()
                msg = msg["error"]
            except Exception:
                msg = await resp.text()

            if resp.status in {404, 422}:
                raise ValueError(msg)
            elif resp.status == 409:
                raise GatewayClusterError(msg)
            elif resp.status == 500:
                raise GatewayServerError(msg)
            else:
                resp.raise_for_status()
        else:
            return resp

    async def _clusters(self, status=None):
        if status is not None:
            if isinstance(status, (str, ClusterStatus)):
                status = [ClusterStatus._create(status)]
            else:
                status = [ClusterStatus._create(s) for s in status]
            query = "?status=" + ",".join(s.name for s in status)
        else:
            query = ""

        url = "%s/api/v1/clusters/%s" % (self.address, query)
        resp = await self._request("GET", url)
        data = await resp.json()
        return [
            ClusterReport._from_json(self._public_address, self.proxy_address,
                                     r) for r in data.values()
        ]

    def list_clusters(self, status=None, **kwargs):
        """List clusters for this user.

        Parameters
        ----------
        status : ClusterStatus, str, or list, optional
            The cluster status (or statuses) to select. Valid options are
            'pending', 'running', 'stopping', 'stopped', 'failed'.
            By default selects active clusters ('pending', 'running').

        Returns
        -------
        clusters : list of ClusterReport
        """
        return self.sync(self._clusters, status=status, **kwargs)

    def get_cluster(self, cluster_name, **kwargs):
        """Get information about a specific cluster.

        Parameters
        ----------
        cluster_name : str
            The cluster name.

        Returns
        -------
        report : ClusterReport
        """
        return self.sync(self._cluster_report, cluster_name, **kwargs)

    async def _get_versions(self):
        url = "%s/api/version" % self.address
        resp = await self._request("GET", url)
        server_info = await resp.json()
        from . import __version__

        return {
            "server": server_info,
            "client": {
                "version": __version__
            },
        }

    def get_versions(self):
        """Return version info for the server and client

        Returns
        -------
        version_info : dict
        """
        return self.sync(self._get_versions)

    def _config_cluster_options(self):
        opts = dask.config.get("gateway.cluster.options")
        return {k: format_template(v) for k, v in opts.items()}

    async def _cluster_options(self, use_local_defaults=True):
        url = "%s/api/v1/options" % self.address
        resp = await self._request("GET", url)
        data = await resp.json()
        options = Options._from_spec(data["cluster_options"])
        if use_local_defaults:
            options.update(self._config_cluster_options())
        return options

    def cluster_options(self, use_local_defaults=True, **kwargs):
        """Get the available cluster configuration options.

        Parameters
        ----------
        use_local_defaults : bool, optional
            Whether to use any default options from the local configuration.
            Default is True, set to False to use only the server-side defaults.

        Returns
        -------
        cluster_options : dask_gateway.options.Options
            A dict of cluster options.
        """
        return self.sync(self._cluster_options,
                         use_local_defaults=use_local_defaults,
                         **kwargs)

    async def _submit(self, cluster_options=None, **kwargs):
        url = "%s/api/v1/clusters/" % self.address
        if cluster_options is not None:
            if not isinstance(cluster_options, Options):
                raise TypeError(
                    "cluster_options must be an `Options`, got %r" %
                    type(cluster_options).__name__)
            options = dict(cluster_options)
            options.update(kwargs)
        else:
            options = self._config_cluster_options()
            options.update(kwargs)
        resp = await self._request("POST",
                                   url,
                                   json={"cluster_options": options})
        data = await resp.json()
        return data["name"]

    def submit(self, cluster_options=None, **kwargs):
        """Submit a new cluster to be started.

        This returns quickly with a ``cluster_name``, which can later be used
        to connect to the cluster.

        Parameters
        ----------
        cluster_options : dask_gateway.options.Options, optional
            An ``Options`` object describing the desired cluster configuration.
        **kwargs :
            Additional cluster configuration options. If ``cluster_options`` is
            provided, these are applied afterwards as overrides. Available
            options are specific to each deployment of dask-gateway, see
            ``cluster_options`` for more information.

        Returns
        -------
        cluster_name : str
            The cluster name.
        """
        return self.sync(self._submit,
                         cluster_options=cluster_options,
                         **kwargs)

    async def _cluster_report(self, cluster_name, wait=False):
        params = "?wait" if wait else ""
        url = "%s/api/v1/clusters/%s%s" % (self.address, cluster_name, params)
        resp = await self._request("GET", url)
        data = await resp.json()
        return ClusterReport._from_json(self._public_address,
                                        self.proxy_address, data)

    async def _wait_for_start(self, cluster_name):
        while True:
            try:
                report = await self._cluster_report(cluster_name, wait=True)
            except TimeoutError:
                # Timeout, ignore
                pass
            else:
                if report.status is ClusterStatus.RUNNING:
                    return report
                elif report.status is ClusterStatus.FAILED:
                    raise GatewayClusterError(
                        "Cluster %r failed to start, see logs for "
                        "more information" % cluster_name)
                elif report.status is ClusterStatus.STOPPED:
                    raise GatewayClusterError("Cluster %r is already stopped" %
                                              cluster_name)
            # Not started yet, try again later
            await asyncio.sleep(0.5)

    def connect(self, cluster_name, shutdown_on_close=False):
        """Connect to a submitted cluster.

        Parameters
        ----------
        cluster_name : str
            The cluster to connect to.
        shutdown_on_close : bool, optional
            If True, the cluster will be automatically shutdown on close.
            Default is False.

        Returns
        -------
        cluster : GatewayCluster
        """
        return GatewayCluster.from_name(
            cluster_name,
            shutdown_on_close=shutdown_on_close,
            address=self.address,
            proxy_address=self.proxy_address,
            auth=self.auth,
            asynchronous=self.asynchronous,
            loop=self.loop,
        )

    def new_cluster(self,
                    cluster_options=None,
                    shutdown_on_close=True,
                    **kwargs):
        """Submit a new cluster to the gateway, and wait for it to be started.

        Same as calling ``submit`` and ``connect`` in one go.

        Parameters
        ----------
        cluster_options : dask_gateway.options.Options, optional
            An ``Options`` object describing the desired cluster configuration.
        shutdown_on_close : bool, optional
            If True (default), the cluster will be automatically shutdown on
            close. Set to False to have cluster persist until explicitly
            shutdown.
        **kwargs :
            Additional cluster configuration options. If ``cluster_options`` is
            provided, these are applied afterwards as overrides. Available
            options are specific to each deployment of dask-gateway, see
            ``cluster_options`` for more information.

        Returns
        -------
        cluster : GatewayCluster
        """
        return GatewayCluster(
            address=self.address,
            proxy_address=self.proxy_address,
            auth=self.auth,
            asynchronous=self.asynchronous,
            loop=self.loop,
            cluster_options=cluster_options,
            shutdown_on_close=shutdown_on_close,
            **kwargs,
        )

    async def _stop_cluster(self, cluster_name):
        url = "%s/api/v1/clusters/%s" % (self.address, cluster_name)
        await self._request("DELETE", url)

    def stop_cluster(self, cluster_name, **kwargs):
        """Stop a cluster.

        Parameters
        ----------
        cluster_name : str
            The cluster name.
        """
        return self.sync(self._stop_cluster, cluster_name, **kwargs)

    async def _scale_cluster(self, cluster_name, n):
        url = "%s/api/v1/clusters/%s/scale" % (self.address, cluster_name)
        await self._request("POST", url, json={"count": n})

    def scale_cluster(self, cluster_name, n, **kwargs):
        """Scale a cluster to n workers.

        Parameters
        ----------
        cluster_name : str
            The cluster name.
        n : int
            The number of workers to scale to.
        """
        return self.sync(self._scale_cluster, cluster_name, n, **kwargs)

    async def _adapt_cluster(self,
                             cluster_name,
                             minimum=None,
                             maximum=None,
                             active=True):
        await self._request(
            "POST",
            "%s/api/v1/clusters/%s/adapt" % (self.address, cluster_name),
            json={
                "minimum": minimum,
                "maximum": maximum,
                "active": active
            },
        )

    def adapt_cluster(self,
                      cluster_name,
                      minimum=None,
                      maximum=None,
                      active=True,
                      **kwargs):
        """Configure adaptive scaling for a cluster.

        Parameters
        ----------
        cluster_name : str
            The cluster name.
        minimum : int, optional
            The minimum number of workers to scale to. Defaults to 0.
        maximum : int, optional
            The maximum number of workers to scale to. Defaults to infinity.
        active : bool, optional
            If ``True`` (default), adaptive scaling is activated. Set to
            ``False`` to deactivate adaptive scaling.
        """
        return self.sync(
            self._adapt_cluster,
            cluster_name,
            minimum=minimum,
            maximum=maximum,
            active=active,
            **kwargs,
        )
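
A hedged sketch of configuring and managing a cluster through this API; it assumes ``gateway.address`` is set in the dask config, and "worker_cores" stands in for whatever options the deployment actually exposes:

gateway = Gateway()                          # address resolved from dask config
opts = gateway.cluster_options()             # server-exposed options, see above
name = gateway.submit(opts, worker_cores=2)  # kwargs are applied as overrides
gateway.adapt_cluster(name, minimum=1, maximum=10)
gateway.stop_cluster(name)
gateway.close()
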
Ejemplo n.º 12
0
class AzureMLCluster(Cluster):
    """ Deploy a Dask cluster using Azure ML

    This creates a dask scheduler and workers on an Azure ML Compute Target.

    Parameters
    ----------
    workspace: azureml.core.Workspace (required)
        Azure ML Workspace - see https://aka.ms/azureml/workspace

    compute_target: azureml.core.ComputeTarget (required)
        Azure ML Compute Target - see https://aka.ms/azureml/computetarget

    environment_definition: azureml.core.Environment (required)
        Azure ML Environment - see https://aka.ms/azureml/environments

    experiment_name: str (optional)
        The name of the Azure ML Experiment used to control the cluster.

        Defaults to ``dask-cloudprovider``.

    initial_node_count: int (optional)
        The initial number of nodes for the Dask Cluster.

        Defaults to ``1``.

    jupyter: bool (optional)
        Flag to start JupyterLab session on the headnode of the cluster.

        Defaults to ``False``.

    jupyter_port: int (optional)
        Port on headnode to use for hosting JupyterLab session.

        Defaults to ``9000``.

    dashboard_port: int (optional)
        Port on headnode to use for hosting Dask dashboard.

        Defaults to ``9001``.

    scheduler_port: int (optional)
        Port to map the scheduler port to via an SSH tunnel if the machine is not on the same VNET.

        Defaults to ``9002``.

    additional_ports: list[tuple[int, int]] (optional)
        Additional ports to forward, given as a list of tuples: the first element of each
        tuple is the port to open on the headnode; the second is the local port it is
        mapped to (or forwarded to via the SSH tunnel).

        Defaults to ``[]``.

    admin_username: str (optional)
        Username of the admin account for the AzureML Compute.
        Required for runs that are not on the same VNET; an Exception is raised
        if the machine is not on the same VNET and no username is provided.

        Defaults to ``""``.

    admin_ssh_key: str (optional)
        Location of the SSH private key used when creating the AzureML Compute.
        The key should be passwordless if run from a Jupyter notebook.
        The ``id_rsa`` file needs to have 0700 permissions set.
        Required for runs that are not on the same VNET; an Exception is raised
        if the machine is not on the same VNET and no key is provided.

        Defaults to ``""``.

    datastores: List[str] (optional)
        List of Azure ML Datastores to be mounted on the headnode -
        see https://aka.ms/azureml/data and https://aka.ms/azureml/datastores.

        Defaults to ``[]``. To mount all datastores in the workspace,
        set to ``[ws.datastores[datastore] for datastore in ws.datastores]``.

    asynchronous: bool (optional)
        Flag to run jobs asynchronously.

    **kwargs: dict
        Additional keyword arguments.
    """

    def __init__(
        self,
        workspace,
        compute_target,
        environment_definition,
        experiment_name=None,
        initial_node_count=None,
        jupyter=None,
        jupyter_port=None,
        dashboard_port=None,
        scheduler_port=None,
        scheduler_idle_timeout=None,
        worker_death_timeout=None,
        additional_ports=None,
        admin_username=None,
        admin_ssh_key=None,
        datastores=None,
        code_store=None,
        asynchronous=False,
        **kwargs,
    ):
        ### REQUIRED PARAMETERS
        self.workspace = workspace
        self.compute_target = compute_target
        self.environment_definition = environment_definition

        ### EXPERIMENT DEFINITION
        self.experiment_name = experiment_name

        ### ENVIRONMENT AND VARIABLES
        self.initial_node_count = initial_node_count

        ### GPU RUN INFO
        self.workspace_vm_sizes = {
            e["name"].lower(): e["gpus"]
            for e in AmlCompute.supported_vmsizes(self.workspace)
        }

        self.compute_target_vm_size = self.compute_target.serialize()["properties"][
            "status"
        ]["vmSize"].lower()
        self.n_gpus_per_node = self.workspace_vm_sizes[self.compute_target_vm_size]
        self.use_gpu = self.n_gpus_per_node > 0

        ### JUPYTER AND PORT FORWARDING
        self.jupyter = jupyter
        self.jupyter_port = jupyter_port
        self.dashboard_port = dashboard_port
        self.scheduler_port = scheduler_port
        self.scheduler_idle_timeout = scheduler_idle_timeout
        self.worker_death_timeout = worker_death_timeout

        if additional_ports is not None:
            if type(additional_ports) != list:
                error_message = (
                    f"The additional_ports parameter is of {type(additional_ports)}"
                    " type but needs to be a list of int tuples."
                    " Check the documentation."
                )
                logger.exception(error_message)
                raise TypeError(error_message)

            if len(additional_ports) > 0:
                if type(additional_ports[0]) != tuple:
                    error_message = (
                        f"The additional_ports elements are of {type(additional_ports[0])}"
                        " type but needs to be a list of int tuples."
                        " Check the documentation."
                    )
                    raise TypeError(error_message)

                ### check if all elements are tuples of length two and int type
                all_correct = True
                for el in additional_ports:
                    if type(el) != tuple or len(el) != 2:
                        all_correct = False
                        break

                    if (type(el[0]), type(el[1])) != (int, int):
                        all_correct = False
                        break

                if not all_correct:
                    error_message = (
                        f"At least one of the elements of the additional_ports parameter"
                        " is wrong. Make sure it is a list of int tuples."
                        " Check the documentation."
                    )
                    raise TypeError(error_message)

        self.additional_ports = additional_ports

        self.admin_username = admin_username
        self.admin_ssh_key = admin_ssh_key
        self.scheduler_ip_port = (
            None  ### INIT FOR HOLDING THE ADDRESS FOR THE SCHEDULER
        )

        ### DATASTORES
        self.datastores = datastores

        ### FUTURE EXTENSIONS
        self.kwargs = kwargs

        ### RUNNING IN MATRIX OR LOCAL
        self.same_vnet = None

        ### GET RUNNING LOOP
        self._loop_runner = LoopRunner(loop=None, asynchronous=asynchronous)
        self.loop = self._loop_runner.loop

        self.abs_path = pathlib.Path(__file__).parent.absolute()

        ### INITIALIZE CLUSTER
        super().__init__(asynchronous=asynchronous)

        if not self.asynchronous:
            self._loop_runner.start()
            self.sync(self.__get_defaults)
            self.sync(self.__create_cluster)

    async def __get_defaults(self):
        self.config = dask.config.get("cloudprovider.azure", {})

        if self.experiment_name is None:
            self.experiment_name = self.config.get("experiment_name")

        if self.initial_node_count is None:
            self.initial_node_count = self.config.get("initial_node_count")

        if self.jupyter is None:
            self.jupyter = self.config.get("jupyter")

        if self.jupyter_port is None:
            self.jupyter_port = self.config.get("jupyter_port")

        if self.dashboard_port is None:
            self.dashboard_port = self.config.get("dashboard_port")

        if self.scheduler_port is None:
            self.scheduler_port = self.config.get("scheduler_port")

        if self.scheduler_idle_timeout is None:
            self.scheduler_idle_timeout = self.config.get("scheduler_idle_timeout")

        if self.worker_death_timeout is None:
            self.worker_death_timeout = self.config.get("worker_death_timeout")

        if self.additional_ports is None:
            self.additional_ports = self.config.get("additional_ports")

        if self.admin_username is None:
            self.admin_username = self.config.get("admin_username")

        if self.admin_ssh_key is None:
            self.admin_ssh_key = self.config.get("admin_ssh_key")

        if self.datastores is None:
            self.datastores = self.config.get("datastores")

        ### PARAMETERS TO START THE CLUSTER
        self.scheduler_params = {}
        self.worker_params = {}

        ### scheduler and worker parameters
        self.scheduler_params["--jupyter"] = self.jupyter
        self.scheduler_params["--scheduler_idle_timeout"] = self.scheduler_idle_timeout
        self.worker_params["--worker_death_timeout"] = self.worker_death_timeout

        if self.use_gpu:
            self.scheduler_params["--use_gpu"] = True
            self.scheduler_params["--n_gpus_per_node"] = self.n_gpus_per_node
            self.worker_params["--use_gpu"] = True
            self.worker_params["--n_gpus_per_node"] = self.n_gpus_per_node

        ### CLUSTER PARAMS
        self.max_nodes = self.compute_target.serialize()["properties"]["properties"][
            "scaleSettings"
        ]["maxNodeCount"]
        self.scheduler_ip_port = None
        self.workers_list = []
        self.URLs = {}

        ### SANITY CHECKS
        ###-----> initial node count
        if self.initial_node_count > self.max_nodes:
            self.initial_node_count = self.max_nodes

    def __print_message(self, msg, length=80, filler="#", pre_post=""):
        logger.info(msg)
        print(f"{pre_post} {msg} {pre_post}".center(length, filler))

    async def __check_if_scheduler_ip_reachable(self):
        """
        Private method to determine whether we are running in the cloud, within
        the same VNET, and the scheduler node is reachable
        """
        try:
            ip, port = self.scheduler_ip_port.split(":")
            socket.create_connection((ip, int(port)), timeout=10).close()  # probe and release
            self.same_vnet = True
            self.__print_message("On the same VNET")
        except socket.timeout:
            self.__print_message("Not on the same VNET")
            self.same_vnet = False
        except ConnectionRefusedError as e:
            logger.info(e)

    def __prepare_rpc_connection_to_headnode(self):
        if not self.same_vnet:
            if self.admin_username == "" or self.admin_ssh_key == "":
                message = "Your machine is not at the same VNET as the cluster. "
                message += "You need to set admin_username and admin_ssh_key. Check documentation."
                logger.exception(message)
                raise Exception(message)
            else:
                uri = f"{socket.gethostname()}:{self.scheduler_port}"
                logger.info(f"Local connection: {uri}")
                return uri
        else:
            return self.run.get_metrics()["scheduler"]

    async def __create_cluster(self):
        # set up environment
        self.__print_message("Setting up cluster")

        # submit run
        self.__print_message("Submitting the experiment")

        exp = Experiment(self.workspace, self.experiment_name)
        estimator = Estimator(
            os.path.join(self.abs_path, "setup"),
            compute_target=self.compute_target,
            entry_script="start_scheduler.py",
            environment_definition=self.environment_definition,
            script_params=self.scheduler_params,
            node_count=1,  ### start only scheduler
            distributed_training=MpiConfiguration(),
            use_docker=True,
            inputs=self.datastores,
        )

        run = exp.submit(estimator)

        self.__print_message("Waiting for scheduler node's IP")

        while (
            run.get_status() not in ("Canceled", "Failed")
            and "scheduler" not in run.get_metrics()
        ):
            print(".", end="")
            logger.info("Scheduler not ready")
            await asyncio.sleep(5)  # non-blocking wait; asyncio import assumed

        if run.get_status() in ("Canceled", "Failed"):
            logger.exception("Failed to start the AzureML cluster")
            raise Exception("Failed to start the AzureML cluster.")

        print("\n\n")

        ### SET FLAGS
        self.scheduler_ip_port = run.get_metrics()["scheduler"]
        self.worker_params["--scheduler_ip_port"] = self.scheduler_ip_port
        self.__print_message(f'Scheduler: {run.get_metrics()["scheduler"]}')
        self.run = run

        logger.info(f'Scheduler: {run.get_metrics()["scheduler"]}')

        ### CHECK IF ON THE SAME VNET
        while self.same_vnet is None:
            await self.sync(self.__check_if_scheduler_ip_reachable)
            await asyncio.sleep(1)  # yield to the event loop between probes

        ### REQUIRED BY dask.distributed.deploy.cluster.Cluster
        _scheduler = self.__prepare_rpc_connection_to_headnode()
        self.scheduler_comm = rpc(_scheduler)
        await self.sync(self.__setup_port_forwarding)
        await self.sync(super()._start)
        await self.sync(self.__update_links)

        self.__print_message("Connections established")
        self.__print_message(f"Scaling to {self.initial_node_count} workers")

        if self.initial_node_count > 1:
            self.scale(
                self.initial_node_count
            )  # LOGIC TO KEEP PROPER TRACK OF WORKERS IN `scale`
        self.__print_message(f"Scaling is done")

    async def __update_links(self):
        hostname = socket.gethostname()
        location = self.workspace.get_details()["location"]
        token = self.run.get_metrics()["token"]

        if self.same_vnet:
            self.scheduler_info[
                "dashboard_url"
            ] = f"https://{hostname}-{self.dashboard_port}.{location}.instances.azureml.net/status"

            self.scheduler_info[
                "jupyter_url"
            ] = f"https://{hostname}-{self.jupyter_port}.{location}.instances.azureml.net/lab?token={token}"
        else:
            self.scheduler_info[
                "dashboard_url"
            ] = f"http://{hostname}:{self.dashboard_port}"
            self.scheduler_info[
                "jupyter_url"
            ] = f"http://{hostname}:{self.jupyter_port}/?token={token}"

        logger.info(f'Dashboard URL: {self.scheduler_info["dashboard_url"]}')
        logger.info(f'Jupyter URL:   {self.scheduler_info["jupyter_url"]}')

    async def __setup_port_forwarding(self):
        dashboard_address = self.run.get_metrics()["dashboard"]
        jupyter_address = self.run.get_metrics()["jupyter"]
        scheduler_ip = self.run.get_metrics()["scheduler"].split(":")[0]

        if self.same_vnet:
            os.system("killall socat")  # kill all socat processes - cleans up previous port forward setups
            os.system(
                f"setsid socat tcp-listen:{self.dashboard_port},reuseaddr,fork tcp:{dashboard_address} &"
            )
            os.system(
                f"setsid socat tcp-listen:{self.jupyter_port},reuseaddr,fork tcp:{jupyter_address} &"
            )

            ### map additional ports
            for port in self.additional_ports:
                os.system(
                    f"setsid socat tcp-listen:{self.port[1]},reuseaddr,fork tcp:{scheduler_ip}:{port[0]} &"
                )
        else:
            scheduler_public_ip = self.compute_target.list_nodes()[0]["publicIpAddress"]
            scheduler_public_port = self.compute_target.list_nodes()[0]["port"]

            cmd = (
                "ssh -vvv -o StrictHostKeyChecking=no -N"
                f" -i {self.admin_ssh_key}"
                f" -L 0.0.0.0:{self.jupyter_port}:{scheduler_ip}:8888"
                f" -L 0.0.0.0:{self.dashboard_port}:{scheduler_ip}:8787"
                f" -L 0.0.0.0:{self.scheduler_port}:{scheduler_ip}:8786"
            )

            for port in self.additional_ports:
                cmd += f" -L 0.0.0.0:{port[1]}:{scheduler_ip}:{port[0]}"

            cmd += f" {self.admin_username}@{scheduler_public_ip} -p {scheduler_public_port}"

            portforward_log = open("portforward_out_log.txt", "w")
            portforward_proc = subprocess.Popen(
                cmd.split(),
                universal_newlines=True,
                stdout=portforward_log,  # write ssh output to the log file instead of an unread pipe
                stderr=subprocess.STDOUT,
            )
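
            ### For illustration only -- every value below is a placeholder,
            ### not from the source. The assembled command resembles:
            ###   ssh -vvv -o StrictHostKeyChecking=no -N -i ~/.ssh/id_rsa \
            ###       -L 0.0.0.0:9999:10.0.0.4:8888 \
            ###       -L 0.0.0.0:9797:10.0.0.4:8787 \
            ###       -L 0.0.0.0:9786:10.0.0.4:8786 \
            ###       azureuser@52.224.0.1 -p 50000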

    @property
    def dashboard_link(self):
        """ Link to Dask dashboard.
        """
        try:
            link = self.scheduler_info["dashboard_url"]
        except KeyError:
            return ""
        else:
            return link

    @property
    def jupyter_link(self):
        """ Link to JupyterLab on running on the headnode of the cluster.
        Set ``jupyter=True`` when creating the ``AzureMLCluster``.
        """
        try:
            link = self.scheduler_info["jupyter_url"]
        except KeyError:
            return ""
        else:
            return link

    def _format_nodes(self, nodes, requested, use_gpu, n_gpus_per_node=None):
        ### the GPU and CPU branches formatted identically, so collapse them
        if nodes == requested:
            return f"{nodes}"
        return f"{nodes} / {requested}"

    def _widget_status(self):
        ### reporting proper number of nodes vs workers in a multi-GPU worker scenario
        nodes = len(self.scheduler_info["workers"])

        if self.use_gpu:
            nodes = int(nodes / self.n_gpus_per_node)
        if hasattr(self, "worker_spec"):
            requested = sum(
                1 if "group" not in each else len(each["group"])
                for each in self.worker_spec.values()
            )

        elif hasattr(self, "nodes"):
            requested = len(self.nodes)
        else:
            requested = nodes

        nodes = self._format_nodes(nodes, requested, self.use_gpu, self.n_gpus_per_node)

        cores = sum(v["nthreads"] for v in self.scheduler_info["workers"].values())
        cores_or_gpus = "Workers (GPUs)" if self.use_gpu else "Workers (vCPUs)"

        memory = (
            sum(
                v["gpu"]["memory-total"][0]
                for v in self.scheduler_info["workers"].values()
            )
            if self.use_gpu
            else sum(v["memory_limit"] for v in self.scheduler_info["workers"].values())
        )
        memory = format_bytes(memory)

        text = """
<div>
  <style scoped>
    .dataframe tbody tr th:only-of-type {
        vertical-align: middle;
    }
    .dataframe tbody tr th {
        vertical-align: top;
    }
    .dataframe thead th {
        text-align: right;
    }
  </style>
  <table style="text-align: right;">
    <tr> <th>Nodes</th> <td>%s</td></tr>
    <tr> <th>%s</th> <td>%s</td></tr>
    <tr> <th>Memory</th> <td>%s</td></tr>
  </table>
</div>
""" % (
            nodes,
            cores_or_gpus,
            cores,
            memory,
        )
        return text

    def _widget(self):
        """ Create IPython widget for display within a notebook """
        try:
            return self._cached_widget
        except AttributeError:
            pass

        try:
            from ipywidgets import Layout, VBox, HBox, IntText, Button, HTML, Accordion
        except ImportError:
            self._cached_widget = None
            return None

        layout = Layout(width="150px")

        if self.dashboard_link:
            dashboard_link = (
                '<p><b>Dashboard: </b><a href="%s" target="_blank">%s</a></p>\n'
                % (self.dashboard_link, self.dashboard_link,)
            )
        else:
            dashboard_link = ""

        if self.jupyter_link:
            jupyter_link = (
                '<p><b>Jupyter: </b><a href="%s" target="_blank">%s</a></p>\n'
                % (self.jupyter_link, self.jupyter_link,)
            )
        else:
            jupyter_link = ""

        title = "<h2>%s</h2>" % self._cluster_class_name
        title = HTML(title)
        dashboard = HTML(dashboard_link)
        jupyter = HTML(jupyter_link)

        status = HTML(self._widget_status(), layout=Layout(min_width="150px"))

        if self._supports_scaling:
            request = IntText(
                self.initial_node_count, description="Nodes", layout=layout
            )
            scale = Button(description="Scale", layout=layout)

            minimum = IntText(0, description="Minimum", layout=layout)
            maximum = IntText(0, description="Maximum", layout=layout)
            adapt = Button(description="Adapt", layout=layout)

            accordion = Accordion(
                [HBox([request, scale]), HBox([minimum, maximum, adapt])],
                layout=Layout(min_width="500px"),
            )
            accordion.selected_index = None
            accordion.set_title(0, "Manual Scaling")
            accordion.set_title(1, "Adaptive Scaling")

            def adapt_cb(b):
                self.adapt(minimum=minimum.value, maximum=maximum.value)
                update()

            adapt.on_click(adapt_cb)

            def scale_cb(b):
                with log_errors():
                    n = request.value
                    with ignoring(AttributeError):
                        self._adaptive.stop()
                    self.scale(n)
                    update()

            scale.on_click(scale_cb)
        else:
            accordion = HTML("")

        box = VBox([title, HBox([status, accordion]), jupyter, dashboard])

        self._cached_widget = box

        def update():
            self.close_when_disconnect()
            status.value = self._widget_status()

        pc = PeriodicCallback(update, 500, io_loop=self.loop)
        self.periodic_callbacks["cluster-repr"] = pc
        pc.start()

        return box

    def close_when_disconnect(self):
        if self.run.get_status() in ("Canceled", "Completed", "Failed"):
            self.scale_down(len(self.workers_list))

    def scale(self, workers=1):
        """ Scale the cluster. Scales to a maximum of the workers available in the cluster.
        """
        if workers <= 0:
            self.close()
            return

        count = len(self.workers_list) + 1  # the headnode also runs one worker

        if count < workers:
            self.scale_up(workers - count)
        elif count > workers:
            self.scale_down(count - workers)
        else:
            self.__print_message(f"Number of workers: {workers}")

    # scale up
    def scale_up(self, workers=1):
        """ Scale up the number of workers.
        """
        run_config = RunConfiguration()
        run_config.target = self.compute_target
        run_config.environment = self.environment_definition

        scheduler_ip = self.run.get_metrics()["scheduler"]
        args = [
            f"--scheduler_ip_port={scheduler_ip}",
            f"--use_gpu={self.use_gpu}",
            f"--n_gpus_per_node={self.n_gpus_per_node}",
            f"--worker_death_timeout={self.worker_death_timeout}",
        ]

        child_run_config = ScriptRunConfig(
            source_directory=os.path.join(self.abs_path, "setup"),
            script="start_worker.py",
            arguments=args,
            run_config=run_config,
        )

        for i in range(workers):
            child_run = self.run.submit_child(child_run_config)
            self.workers_list.append(child_run)

    # scale down
    def scale_down(self, workers=1):
        """ Scale down the number of workers. Scales to minimum of 1.
        """
        for i in range(workers):
            if self.workers_list:
                child_run = self.workers_list.pop(0)  # deactivate the oldest worker
                child_run.complete()  # complete() will mark the run "Complete", but won't kill the process
                child_run.cancel()
            else:
                self.__print_message("All scaled workers are removed.")

    # close cluster
    async def _close(self):
        if self.status == "closed":
            return
        while self.workers_list:
            child_run = self.workers_list.pop()
            child_run.complete()
            child_run.cancel()

        if self.run:
            self.run.complete()
            self.run.cancel()

        await super()._close()
        self.status = "closed"
        self.__print_message("Scheduler and workers are disconnected.")

    def close(self):
        """ Close the cluster. All Azure ML Runs corresponding to the scheduler
        and worker processes will be completed. The Azure ML Compute Target will
        return to its minimum number of nodes after its idle time before scaledown.
        """
        return self.sync(self._close)
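For orientation, a minimal usage sketch of the cluster class above. The workspace, compute target, and environment names are assumptions, not from the source; the snippet presumes those AzureML resources already exist.

from azureml.core import Workspace, Environment
from azureml.core.compute import ComputeTarget

ws = Workspace.from_config()                          # assumes a local config.json
ct = ComputeTarget(workspace=ws, name="dask-ct")      # "dask-ct" is a placeholder
env = Environment.get(workspace=ws, name="dask-env")  # "dask-env" is a placeholder

cluster = AzureMLCluster(
    workspace=ws,
    compute_target=ct,
    environment_definition=env,
    initial_node_count=2,
)
cluster.scale(4)   # scales worker child runs up or down; the headnode counts as one worker
cluster.close()    # completes and cancels the scheduler and worker runs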
Ejemplo n.º 13
0
class HelmCluster(Cluster):
    """Connect to a Dask cluster deployed via the Helm Chart.

    This cluster manager connects to an existing Dask deployment that was
    created by the Dask Helm Chart, enabling you to perform basic cluster
    actions such as scaling and log retrieval.

    Parameters
    ----------
    release_name: str
        Name of the helm release to connect to.
    namespace: str (optional)
        Namespace in which to launch the workers.
        Defaults to current namespace if available or "default"
    port_forward_cluster_ip: bool (optional)
        If the chart uses ClusterIP type services, forward the ports locally.
        If you are using ``HelmCluster`` from the Jupyter session that was installed
        by the helm chart this should be ``False``. If you are running it locally it should
        be ``True``.
    auth: List[ClusterAuth] (optional)
        Configuration methods to attempt in order.  Defaults to
        ``[InCluster(), KubeConfig()]``.
    scheduler_name: str (optional)
        Name of the Dask scheduler deployment in the current release.
        Defaults to "scheduler".
    worker_name: str (optional)
        Name of the Dask worker deployment in the current release.
        Defaults to "worker".
    **kwargs: dict
        Additional keyword arguments to pass to Cluster

    Examples
    --------
    >>> from dask_kubernetes import HelmCluster
    >>> cluster = HelmCluster(release_name="myhelmrelease")

    You can then resize the cluster with the scale method

    >>> cluster.scale(10)

    You can pass this cluster directly to a Dask client

    >>> from dask.distributed import Client
    >>> client = Client(cluster)

    You can also access cluster logs

    >>> cluster.get_logs()

    See Also
    --------
    HelmCluster.scale
    HelmCluster.logs
    """
    def __init__(
        self,
        release_name=None,
        auth=ClusterAuth.DEFAULT,
        namespace=None,
        port_forward_cluster_ip=False,
        loop=None,
        asynchronous=False,
        scheduler_name="scheduler",
        worker_name="worker",
    ):
        self.release_name = release_name
        self.namespace = namespace or _namespace_default()
        self.check_helm_dependency()
        status = subprocess.run(
            ["helm", "-n", self.namespace, "status", self.release_name],
            capture_output=True,
            encoding="utf-8",
        )
        if status.returncode != 0:
            raise RuntimeError(f"No such helm release {self.release_name}.")
        self.auth = auth
        self.core_api = None
        self.scheduler_comm = None
        self.port_forward_cluster_ip = port_forward_cluster_ip
        self._supports_scaling = True
        self._loop_runner = LoopRunner(loop=loop, asynchronous=asynchronous)
        self.loop = self._loop_runner.loop
        self.scheduler_name = scheduler_name
        self.worker_name = worker_name

        super().__init__(asynchronous=asynchronous)
        if not self.asynchronous:
            self._loop_runner.start()
            self.sync(self._start)

    @staticmethod
    def check_helm_dependency():
        if shutil.which("helm") is None:
            raise RuntimeError(
                "Missing dependency helm. "
                "Please install helm following the instructions for your OS. "
                "https://helm.sh/docs/intro/install/")

    async def _start(self):
        await ClusterAuth.load_first(self.auth)
        self.core_api = kubernetes.client.CoreV1Api()
        self.apps_api = kubernetes.client.AppsV1Api()
        self.scheduler_comm = rpc(await self._get_scheduler_address())
        await super()._start()

    async def _get_scheduler_address(self):
        service_name = f"{self.release_name}-{self.scheduler_name}"
        service = await self.core_api.read_namespaced_service(
            service_name, self.namespace)
        [port] = [
            port.port for port in service.spec.ports
            if port.name == service_name
        ]
        if service.spec.type == "LoadBalancer":
            lb = service.status.load_balancer.ingress[0]
            host = lb.hostname or lb.ip
            return f"tcp://{host}:{port}"
        elif service.spec.type == "NodePort":
            nodes = await self.core_api.list_node()
            host = nodes.items[0].status.addresses[0].address
            return f"tcp://{host}:{port}"
        elif service.spec.type == "ClusterIP":
            if self.port_forward_cluster_ip:
                warnings.warn(
                    f"""
                    Sorry we do not currently support local port forwarding.

                    Please port-forward the service locally yourself with the following command.

                    kubectl port-forward --namespace {self.namespace} svc/{service_name} {port}:{port} &
                    """
                )  # FIXME Handle this port forward here with the kubernetes library
                return f"tcp://localhost:{port}"
            return f"tcp://{service.spec.cluster_ip}:{port}"
        raise RuntimeError("Unable to determine scheduler address.")

    async def _wait_for_workers(self):
        while True:
            n_workers = len(self.scheduler_info["workers"])
            deployment = await self.apps_api.read_namespaced_deployment(
                name=f"{self.release_name}-{self.worker_name}",
                namespace=self.namespace)
            deployment_replicas = deployment.spec.replicas
            if n_workers == deployment_replicas:
                return
            else:
                await asyncio.sleep(0.2)

    def get_logs(self):
        """Get logs for Dask scheduler and workers.

        Examples
        --------
        >>> cluster.get_logs()
        {'testdask-scheduler-5c8ffb6b7b-sjgrg': ...,
        'testdask-worker-64c8b78cc-992z8': ...,
        'testdask-worker-64c8b78cc-hzpdc': ...,
        'testdask-worker-64c8b78cc-wbk4f': ...}

        Each log will be a string of all logs for that container. To view
        them it is recommended that you print each log.

        >>> print(cluster.get_logs()["testdask-scheduler-5c8ffb6b7b-sjgrg"])
        ...
        distributed.scheduler - INFO - -----------------------------------------------
        distributed.scheduler - INFO - Clear task state
        distributed.scheduler - INFO -   Scheduler at:     tcp://10.1.6.131:8786
        distributed.scheduler - INFO -   dashboard at:                     :8787
        ...
        """
        return self.sync(self._get_logs)

    async def _get_logs(self):
        logs = Logs()

        pods = await self.core_api.list_namespaced_pod(
            namespace=self.namespace,
            label_selector=f"release={self.release_name},app=dask",
        )

        for pod in pods.items:
            if "scheduler" in pod.metadata.name or "worker" in pod.metadata.name:
                logs[pod.metadata.name] = Log(
                    await self.core_api.read_namespaced_pod_log(
                        pod.metadata.name, pod.metadata.namespace))

        return logs

    def __await__(self):
        async def _():
            if self.status == "created":
                await self._start()
            elif self.status == "running":
                await self._wait_for_workers()
            return self

        return _().__await__()

    def scale(self, n_workers):
        """Scale cluster to n workers.

        This sets the Dask worker deployment size to the requested number.
        Workers will not be terminated gracefully, so be sure to only scale down
        when all futures have been retrieved by the client and the cluster is idle.

        Examples
        --------

        >>> cluster
        HelmCluster('tcp://localhost:8786', workers=3, threads=18, memory=18.72 GB)
        >>> cluster.scale(4)
        >>> cluster
        HelmCluster('tcp://localhost:8786', workers=4, threads=24, memory=24.96 GB)

        """
        return self.sync(self._scale, n_workers)

    async def _scale(self, n_workers):
        await self.apps_api.patch_namespaced_deployment(
            name=f"{self.release_name}-{self.worker_name}",
            namespace=self.namespace,
            body={"spec": {
                "replicas": n_workers,
            }},
        )

    def adapt(self, *args, **kwargs):
        """Turn on adaptivity (Not recommended)."""
        raise NotImplementedError(
            "It is not recommended to run ``HelmCluster`` in adaptive mode. "
            "When scaling down workers the decision on which worker to remove is left to Kubernetes, which "
            "will not necessarily remove the same worker that Dask would choose. This may result in lost futures and "
            "recalculation. It is recommended to manage scaling yourself with the ``HelmCluster.scale`` method."
        )

    async def _adapt(self, *args, **kwargs):
        return super().adapt(*args, **kwargs)
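For completeness, a hedged sketch of the manual port-forward workflow that the ``ClusterIP`` warning above prints (release, namespace, and port values are placeholders, not from the source):

# In a shell, run the forward the warning suggests:
#   kubectl port-forward --namespace dask svc/myrelease-scheduler 8786:8786 &
from dask.distributed import Client

client = Client("tcp://localhost:8786")  # the locally forwarded scheduler port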
Ejemplo n.º 14
0
def test_two_loop_runners(loop_in_thread):
    # Loop runners tied to the same loop should cooperate

    # ABCCBA
    loop = IOLoop()
    a = LoopRunner(loop=loop)
    b = LoopRunner(loop=loop)
    assert_not_running(loop)
    a.start()
    assert_running(loop)
    c = LoopRunner(loop=loop)
    b.start()
    assert_running(loop)
    c.start()
    assert_running(loop)
    c.stop()
    assert_running(loop)
    b.stop()
    assert_running(loop)
    a.stop()
    assert_not_running(loop)

    # ABCABC
    loop = IOLoop()
    a = LoopRunner(loop=loop)
    b = LoopRunner(loop=loop)
    assert_not_running(loop)
    a.start()
    assert_running(loop)
    b.start()
    assert_running(loop)
    c = LoopRunner(loop=loop)
    c.start()
    assert_running(loop)
    a.stop()
    assert_running(loop)
    b.stop()
    assert_running(loop)
    c.stop()
    assert_not_running(loop)

    # Explicit loop, already started
    a = LoopRunner(loop=loop_in_thread)
    b = LoopRunner(loop=loop_in_thread)
    assert_running(loop_in_thread)
    a.start()
    assert_running(loop_in_thread)
    b.start()
    assert_running(loop_in_thread)
    a.stop()
    assert_running(loop_in_thread)
    b.stop()
    assert_running(loop_in_thread)
Ejemplo n.º 15
0
def test_loop_runner(loop_in_thread):
    # Implicit loop
    loop = IOLoop()
    loop.make_current()
    runner = LoopRunner()
    assert runner.loop not in (loop, loop_in_thread)
    assert not runner.is_started()
    assert_not_running(runner.loop)
    runner.start()
    assert runner.is_started()
    assert_running(runner.loop)
    runner.stop()
    assert not runner.is_started()
    assert_not_running(runner.loop)

    # Explicit loop
    loop = IOLoop()
    runner = LoopRunner(loop=loop)
    assert runner.loop is loop
    assert not runner.is_started()
    assert_not_running(loop)
    runner.start()
    assert runner.is_started()
    assert_running(loop)
    runner.stop()
    assert not runner.is_started()
    assert_not_running(loop)

    # Explicit loop, already started
    runner = LoopRunner(loop=loop_in_thread)
    assert not runner.is_started()
    assert_running(loop_in_thread)
    runner.start()
    assert runner.is_started()
    assert_running(loop_in_thread)
    runner.stop()
    assert not runner.is_started()
    assert_running(loop_in_thread)

    # Implicit loop, asynchronous=True
    loop = IOLoop()
    loop.make_current()
    runner = LoopRunner(asynchronous=True)
    assert runner.loop is loop
    assert not runner.is_started()
    assert_not_running(runner.loop)
    runner.start()
    assert runner.is_started()
    assert_not_running(runner.loop)
    runner.stop()
    assert not runner.is_started()
    assert_not_running(runner.loop)

    # Explicit loop, asynchronous=True
    loop = IOLoop()
    runner = LoopRunner(loop=loop, asynchronous=True)
    assert runner.loop is loop
    assert not runner.is_started()
    assert_not_running(runner.loop)
    runner.start()
    assert runner.is_started()
    assert_not_running(runner.loop)
    runner.stop()
    assert not runner.is_started()
    assert_not_running(runner.loop)
Ejemplo n.º 16
0
class HelmCluster(Cluster):
    """Connect to a Dask cluster deployed via the Helm Chart.

    This cluster manager connects to an existing Dask deployment that was
    created by the Dask Helm Chart, enabling you to perform basic cluster
    actions such as scaling and log retrieval.

    Parameters
    ----------
    release_name: str
        Name of the helm release to connect to.
    namespace: str (optional)
        Namespace in which to launch the workers.
        Defaults to current namespace if available or "default"
    port_forward_cluster_ip: bool (optional)
        If the chart uses ClusterIP type services, forward the ports locally.
        If you are using ``HelmCluster`` from the Jupyter session that was installed
        by the helm chart this should be ``False``. If you are running it locally it should
        be ``True``.
    auth: List[ClusterAuth] (optional)
        Configuration methods to attempt in order.  Defaults to
        ``[InCluster(), KubeConfig()]``.
    scheduler_name: str (optional)
        Name of the Dask scheduler deployment in the current release.
        Defaults to "scheduler".
    worker_name: str (optional)
        Name of the Dask worker deployment in the current release.
        Defaults to "worker".
    node_host: str (optional)
        A node address. Can be provided in case scheduler service type is
        ``NodePort`` and you want to manually specify which node to connect to.
    node_port: int (optional)
        A node port. Can be provided in case scheduler service type is
        ``NodePort`` and you want to manually specify which port to connect to.
    **kwargs: dict
        Additional keyword arguments to pass to Cluster.

    Examples
    --------
    >>> from dask_kubernetes import HelmCluster
    >>> cluster = HelmCluster(release_name="myhelmrelease")

    You can then resize the cluster with the scale method

    >>> cluster.scale(10)

    You can pass this cluster directly to a Dask client

    >>> from dask.distributed import Client
    >>> client = Client(cluster)

    You can also access cluster logs

    >>> cluster.get_logs()

    See Also
    --------
    HelmCluster.scale
    HelmCluster.logs
    """
    def __init__(
        self,
        release_name=None,
        auth=ClusterAuth.DEFAULT,
        namespace=None,
        port_forward_cluster_ip=False,
        loop=None,
        asynchronous=False,
        scheduler_name="scheduler",
        worker_name="worker",
        node_host=None,
        node_port=None,
        **kwargs,
    ):
        self.release_name = release_name
        self.namespace = namespace or namespace_default()
        check_dependency("helm")
        check_dependency("kubectl")
        status = subprocess.run(
            ["helm", "-n", self.namespace, "status", self.release_name],
            capture_output=True,
            encoding="utf-8",
        )
        if status.returncode != 0:
            raise RuntimeError(f"No such helm release {self.release_name}.")
        self.auth = auth
        self.core_api = None
        self.scheduler_comm = None
        self.port_forward_cluster_ip = port_forward_cluster_ip
        self._supports_scaling = True
        self._loop_runner = LoopRunner(loop=loop, asynchronous=asynchronous)
        self.loop = self._loop_runner.loop
        self.scheduler_name = scheduler_name
        self.worker_name = worker_name
        self.node_host = node_host
        self.node_port = node_port

        super().__init__(asynchronous=asynchronous, **kwargs)
        if not self.asynchronous:
            self._loop_runner.start()
            self.sync(self._start)

    async def _start(self):
        await ClusterAuth.load_first(self.auth)
        self.core_api = kubernetes.client.CoreV1Api()
        self.apps_api = kubernetes.client.AppsV1Api()
        self.scheduler_comm = rpc(await self._get_scheduler_address())
        await super()._start()

    async def _get_scheduler_address(self):
        # Get the chart name
        chart = subprocess.check_output(
            [
                "helm",
                "-n",
                self.namespace,
                "list",
                "-f",
                self.release_name,
                "--output",
                "json",
            ],
            encoding="utf-8",
        )
        chart = json.loads(chart)[0]["chart"]
        # extract name from {{.Chart.Name }}-{{ .Chart.Version }}
        chart_name = "-".join(chart.split("-")[:-1])
        # Follow the spec in the dask/dask helm chart
        self.chart_name = (f"{chart_name}-"
                           if chart_name not in self.release_name else "")
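        ### e.g. (placeholder values) release "dask" with chart "dask-2021.3.0"
        ### strips the prefix -> service "dask-scheduler", while release "my"
        ### keeps it -> service "my-dask-scheduler"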

        service_name = f"{self.release_name}-{self.chart_name}{self.scheduler_name}"
        service = await self.core_api.read_namespaced_service(
            service_name, self.namespace)
        address = await get_external_address_for_scheduler_service(
            self.core_api,
            service,
            port_forward_cluster_ip=self.port_forward_cluster_ip)
        if address is None:
            raise RuntimeError("Unable to determine scheduler address.")
        return address

    async def _wait_for_workers(self):
        while True:
            n_workers = len(self.scheduler_info["workers"])
            deployment = await self.apps_api.read_namespaced_deployment(
                name=f"{self.release_name}-{self.chart_name}{self.worker_name}",
                namespace=self.namespace,
            )
            deployment_replicas = deployment.spec.replicas
            if n_workers == deployment_replicas:
                return
            else:
                await asyncio.sleep(0.2)

    def get_logs(self):
        """Get logs for Dask scheduler and workers.

        Examples
        --------
        >>> cluster.get_logs()
        {'testdask-scheduler-5c8ffb6b7b-sjgrg': ...,
        'testdask-worker-64c8b78cc-992z8': ...,
        'testdask-worker-64c8b78cc-hzpdc': ...,
        'testdask-worker-64c8b78cc-wbk4f': ...}

        Each log will be a string of all logs for that container. To view
        them it is recommended that you print each log.

        >>> print(cluster.get_logs()["testdask-scheduler-5c8ffb6b7b-sjgrg"])
        ...
        distributed.scheduler - INFO - -----------------------------------------------
        distributed.scheduler - INFO - Clear task state
        distributed.scheduler - INFO -   Scheduler at:     tcp://10.1.6.131:8786
        distributed.scheduler - INFO -   dashboard at:                     :8787
        ...
        """
        return self.sync(self._get_logs)

    async def _get_logs(self):
        logs = Logs()

        pods = await self.core_api.list_namespaced_pod(
            namespace=self.namespace,
            label_selector=f"release={self.release_name},app=dask",
        )

        for pod in pods.items:
            if "scheduler" in pod.metadata.name or "worker" in pod.metadata.name:
                try:
                    if pod.status.phase != "Running":
                        raise ValueError(
                            f"Cannot get logs for pod with status {pod.status.phase}.",
                        )
                    log = Log(await self.core_api.read_namespaced_pod_log(
                        pod.metadata.name, pod.metadata.namespace))
                except (ValueError, kubernetes.client.exceptions.ApiException):
                    log = Log(f"Cannot find logs. Pod is {pod.status.phase}.")
                logs[pod.metadata.name] = log

        return logs

    def __await__(self):
        async def _():
            if self.status == Status.created:
                await self._start()
            elif self.status == Status.running:
                await self._wait_for_workers()
            return self

        return _().__await__()

    def scale(self, n_workers):
        """Scale cluster to n workers.

        This sets the Dask worker deployment size to the requested number.
        Workers will not be terminated gracefully, so be sure to only scale down
        when all futures have been retrieved by the client and the cluster is idle.

        Examples
        --------

        >>> cluster
        HelmCluster('tcp://localhost:8786', workers=3, threads=18, memory=18.72 GB)
        >>> cluster.scale(4)
        >>> cluster
        HelmCluster('tcp://localhost:8786', workers=4, threads=24, memory=24.96 GB)

        """
        return self.sync(self._scale, n_workers)

    async def _scale(self, n_workers):
        await self.apps_api.patch_namespaced_deployment(
            name=f"{self.release_name}-{self.chart_name}{self.worker_name}",
            namespace=self.namespace,
            body={"spec": {
                "replicas": n_workers,
            }},
        )

    def adapt(self, *args, **kwargs):
        """Turn on adaptivity (Not recommended)."""
        raise NotImplementedError(
            "It is not recommended to run ``HelmCluster`` in adaptive mode. "
            "When scaling down workers the decision on which worker to remove is left to Kubernetes, which "
            "will not necessarily remove the same worker that Dask would choose. This may result in lost futures and "
            "recalculation. It is recommended to manage scaling yourself with the ``HelmCluster.scale`` method."
        )

    async def _adapt(self, *args, **kwargs):
        return super().adapt(*args, **kwargs)