def test_close_loop_sync(with_own_loop):
    """Check that a synchronous cluster/client pair tears down its IO loops.

    Parameters
    ----------
    with_own_loop : bool
        When true, the test starts its own ``LoopRunner`` and hands that loop
        to both ``LocalCluster`` and ``Client``; otherwise ``loop=None`` is
        passed and they pick their loop themselves.
    """
    loop_runner = loop = None

    # Setup simple cluster with one threaded worker.
    # Complex setup is not required here since we test only IO loop teardown.
    cluster_params = dict(n_workers=1, dashboard_address=None, processes=False)

    loops_before = LoopRunner._all_loops.copy()

    # Start own loop or use current thread's one.
    if with_own_loop:
        loop_runner = LoopRunner()
        loop_runner.start()
        loop = loop_runner.loop

    try:
        with LocalCluster(loop=loop, **cluster_params) as cluster:
            with Client(cluster, loop=loop) as client:
                client.run(max, 1, 2)
    finally:
        # Our own loop must be stopped explicitly. Doing it in ``finally``
        # keeps a failure above from leaving the loop running for the
        # remainder of the test session.
        if with_own_loop:
            loop_runner.stop()

    # Internal loops registry must be the same as before the cluster ran.
    # This means loop runners in LocalCluster and Client correctly stopped.
    # See LoopRunner._stop_unlocked().
    assert loops_before == LoopRunner._all_loops
async def test_loop_runner_gen():
    """Start and stop an asynchronous ``LoopRunner`` from inside a coroutine."""
    pause = 0.01

    loop_runner = LoopRunner(asynchronous=True)
    # In asynchronous mode the runner is expected to adopt the current loop.
    assert loop_runner.loop is IOLoop.current()
    assert not loop_runner.is_started()
    await asyncio.sleep(pause)

    loop_runner.start()
    assert loop_runner.is_started()
    await asyncio.sleep(pause)

    loop_runner.stop()
    assert not loop_runner.is_started()
    await asyncio.sleep(pause)
def test_loop_runner_gen():
    """Start and stop an asynchronous ``LoopRunner`` from a generator coroutine."""
    pause = 0.01

    loop_runner = LoopRunner(asynchronous=True)
    # In asynchronous mode the runner is expected to adopt the current loop.
    assert loop_runner.loop is IOLoop.current()
    assert not loop_runner.is_started()
    yield gen.sleep(pause)

    loop_runner.start()
    assert loop_runner.is_started()
    yield gen.sleep(pause)

    loop_runner.stop()
    assert not loop_runner.is_started()
    yield gen.sleep(pause)
class ContextCluster(Cluster):
    """Cluster that starts on construction and can be used as a context manager.

    Parameters
    ----------
    asynchronous : bool, optional
        Forwarded to ``LoopRunner`` and to ``Cluster.__init__``. When the
        resulting ``self.asynchronous`` is false, the loop runner is started
        and ``self._start`` is run through ``self.sync`` in the constructor.
    loop : IOLoop, optional
        Loop handed to the ``LoopRunner``.
    """

    def __init__(self, asynchronous=False, loop=None):
        self._loop_runner = LoopRunner(loop=loop, asynchronous=asynchronous)
        self.loop = self._loop_runner.loop

        super().__init__(asynchronous=asynchronous)

        if not self.asynchronous:
            self._loop_runner.start()
            self.sync(self._start)

    def __enter__(self):
        """Return the cluster itself.

        Raises
        ------
        ValueError
            If ``self.status`` is not ``"running"``.
        """
        if self.status != "running":
            raise ValueError(
                f"Expected status 'running', found '{self.status}'")
        return self

    def __exit__(self, typ, value, traceback):
        """Close the cluster and stop the loop runner."""
        try:
            self.close()
        finally:
            # Stop the runner even when close() raises, so a failed shutdown
            # does not also leave the loop runner started.
            self._loop_runner.stop()
def test_two_loop_runners(loop_in_thread):
    """Several ``LoopRunner`` objects bound to one loop must cooperate.

    The assertions expect the loop to keep running until every runner that
    was started has been stopped, whatever the order of the stops, and expect
    a loop that was already running beforehand to stay running afterwards.
    """
    # Stops mirror the starts: A B C C B A
    loop = IOLoop()
    first = LoopRunner(loop=loop)
    second = LoopRunner(loop=loop)
    assert_not_running(loop)
    first.start()
    assert_running(loop)
    third = LoopRunner(loop=loop)
    for runner in (second, third):
        runner.start()
        assert_running(loop)
    for runner in (third, second):
        runner.stop()
        assert_running(loop)
    first.stop()
    assert_not_running(loop)

    # Stops follow the start order: A B C A B C
    loop = IOLoop()
    first = LoopRunner(loop=loop)
    second = LoopRunner(loop=loop)
    assert_not_running(loop)
    for runner in (first, second):
        runner.start()
        assert_running(loop)
    third = LoopRunner(loop=loop)
    third.start()
    assert_running(loop)
    for runner in (first, second):
        runner.stop()
        assert_running(loop)
    third.stop()
    assert_not_running(loop)

    # Explicit loop, already started
    first = LoopRunner(loop=loop_in_thread)
    second = LoopRunner(loop=loop_in_thread)
    assert_running(loop_in_thread)
    for runner in (first, second):
        runner.start()
        assert_running(loop_in_thread)
    for runner in (first, second):
        runner.stop()
        assert_running(loop_in_thread)
def test_loop_runner(loop_in_thread):
    """Exercise the ``LoopRunner`` start/stop cycle for each loop configuration."""

    def check_lifecycle(runner, target_loop, check_idle, check_started):
        # One full cycle: not started -> started -> stopped again.
        # ``check_idle`` is applied before start() and after stop(),
        # ``check_started`` while the runner is started.
        assert not runner.is_started()
        check_idle(target_loop)
        runner.start()
        assert runner.is_started()
        check_started(target_loop)
        runner.stop()
        assert not runner.is_started()
        check_idle(target_loop)

    # Implicit loop
    io_loop = IOLoop()
    io_loop.make_current()
    runner = LoopRunner()
    assert runner.loop not in (io_loop, loop_in_thread)
    check_lifecycle(runner, runner.loop, assert_not_running, assert_running)

    # Explicit loop
    io_loop = IOLoop()
    runner = LoopRunner(loop=io_loop)
    assert runner.loop is io_loop
    check_lifecycle(runner, io_loop, assert_not_running, assert_running)

    # Explicit loop, already started
    runner = LoopRunner(loop=loop_in_thread)
    check_lifecycle(runner, loop_in_thread, assert_running, assert_running)

    # Implicit loop, asynchronous=True
    io_loop = IOLoop()
    io_loop.make_current()
    runner = LoopRunner(asynchronous=True)
    assert runner.loop is io_loop
    check_lifecycle(runner, runner.loop, assert_not_running, assert_not_running)

    # Explicit loop, asynchronous=True
    io_loop = IOLoop()
    runner = LoopRunner(loop=io_loop, asynchronous=True)
    assert runner.loop is io_loop
    check_lifecycle(runner, runner.loop, assert_not_running, assert_not_running)
class YarnCluster(object):
    """Start a Dask cluster on YARN.

    You can define default values for this in Dask's ``yarn.yaml``
    configuration file. See http://docs.dask.org/en/latest/configuration.html
    for more information.

    Parameters
    ----------
    environment : str, optional
        The Python environment to use. Can be one of the following:

          - A path to an archived Python environment
          - A path to a conda environment, specified as `conda:///...`
          - A path to a virtual environment, specified as `venv:///...`
          - A path to a python executable, specified as `python:///...`

        Note that if not an archive, the paths specified must be valid on all
        nodes in the cluster.
    n_workers : int, optional
        The number of workers to initially start.
    worker_vcores : int, optional
        The number of virtual cores to allocate per worker.
    worker_memory : str, optional
        The amount of memory to allocate per worker. Accepts a unit suffix
        (e.g. '2 GiB' or '4096 MiB'). Will be rounded up to the nearest MiB.
    worker_restarts : int, optional
        The maximum number of worker restarts to allow before failing the
        application. Default is unlimited.
    worker_env : dict, optional
        A mapping of environment variables to their values. These will be set
        in the worker containers before starting the dask workers.
    scheduler_vcores : int, optional
        The number of virtual cores to allocate per scheduler.
    scheduler_memory : str, optional
        The amount of memory to allocate to the scheduler. Accepts a unit
        suffix (e.g. '2 GiB' or '4096 MiB'). Will be rounded up to the nearest
        MiB.
    deploy_mode : {'remote', 'local'}, optional
        The deploy mode to use. If ``'remote'``, the scheduler will be deployed
        in a YARN container. If ``'local'``, the scheduler will run locally,
        which can be nice for debugging. Default is ``'remote'``.
    name : str, optional
        The application name.
    queue : str, optional
        The queue to deploy to.
    tags : sequence, optional
        A set of strings to use as tags for this application.
    user : str, optional
        The user to submit the application on behalf of. Default is the current
        user - submitting as a different user requires user permissions, see
        the YARN documentation for more information.
    host : str, optional
        Host address on which the scheduler will listen. Only used if
        ``deploy_mode='local'``. Defaults to ``'0.0.0.0'``.
    port : int, optional
        The port on which the scheduler will listen. Only used if
        ``deploy_mode='local'``. Defaults to ``0`` for a random port.
    dashboard_address : str
        Address on which the dashboard server will listen. Only used if
        ``deploy_mode='local'``. Defaults to ':0' for a random port.
    skein_client : skein.Client, optional
        The ``skein.Client`` to use. If not provided, one will be started.
    asynchronous : bool, optional
        If true, starts the cluster in asynchronous mode, where it can be used
        in other async code.
    loop : IOLoop, optional
        The IOLoop instance to use. Defaults to the current loop in
        asynchronous mode, otherwise a background loop is started.

    Examples
    --------
    >>> cluster = YarnCluster(environment='my-env.tar.gz', ...)
    >>> cluster.scale(10)
    """

    def __init__(
        self,
        environment=None,
        n_workers=None,
        worker_vcores=None,
        worker_memory=None,
        worker_restarts=None,
        worker_env=None,
        scheduler_vcores=None,
        scheduler_memory=None,
        deploy_mode=None,
        name=None,
        queue=None,
        tags=None,
        user=None,
        host=None,
        port=None,
        dashboard_address=None,
        skein_client=None,
        asynchronous=False,
        loop=None,
    ):
        spec = _make_specification(
            environment=environment,
            n_workers=n_workers,
            worker_vcores=worker_vcores,
            worker_memory=worker_memory,
            worker_restarts=worker_restarts,
            worker_env=worker_env,
            scheduler_vcores=scheduler_vcores,
            scheduler_memory=scheduler_memory,
            deploy_mode=deploy_mode,
            name=name,
            queue=queue,
            tags=tags,
            user=user,
        )
        self._init_common(
            spec=spec,
            host=host,
            port=port,
            dashboard_address=dashboard_address,
            asynchronous=asynchronous,
            loop=loop,
            skein_client=skein_client,
        )

    @classmethod
    def from_specification(cls, spec, skein_client=None, asynchronous=False, loop=None):
        """Start a dask cluster from a skein specification.

        Parameters
        ----------
        spec : skein.ApplicationSpec, dict, or filename
            The application specification to use. Must define at least one
            service: ``'dask.worker'``. If no ``'dask.scheduler'`` service is
            defined, a scheduler will be started locally.
        skein_client : skein.Client, optional
            The ``skein.Client`` to use. If not provided, one will be started.
        asynchronous : bool, optional
            If true, starts the cluster in asynchronous mode, where it can be
            used in other async code.
        loop : IOLoop, optional
            The IOLoop instance to use. Defaults to the current loop in
            asynchronous mode, otherwise a background loop is started.

        Returns
        -------
        YarnCluster

        Raises
        ------
        TypeError
            If ``spec`` is not an ``ApplicationSpec``, dict, or str.
        ValueError
            If the specification has no ``'dask.worker'`` service.
        """
        # Bypass __init__: it builds its own specification from keyword args.
        self = super(YarnCluster, cls).__new__(cls)
        if isinstance(spec, dict):
            spec = skein.ApplicationSpec.from_dict(spec)
        elif isinstance(spec, str):
            spec = skein.ApplicationSpec.from_file(spec)
        elif not isinstance(spec, skein.ApplicationSpec):
            raise TypeError(
                "spec must be an ApplicationSpec, dict, or path, "
                "got %r" % type(spec).__name__
            )
        if "dask.worker" not in spec.services:
            raise ValueError(
                "Provided Skein specification must include a 'dask.worker' service"
            )
        self._init_common(
            spec=spec,
            asynchronous=asynchronous,
            loop=loop,
            skein_client=skein_client,
        )
        return self

    @classmethod
    def from_current(cls, asynchronous=False, loop=None):
        """Connect to an existing ``YarnCluster`` from inside the cluster.

        Parameters
        ----------
        asynchronous : bool, optional
            If true, starts the cluster in asynchronous mode, where it can be
            used in other async code.
        loop : IOLoop, optional
            The IOLoop instance to use. Defaults to the current loop in
            asynchronous mode, otherwise a background loop is started.

        Returns
        -------
        YarnCluster
        """
        self = super(YarnCluster, cls).__new__(cls)
        app_id = os.environ.get("DASK_APPLICATION_ID", None)
        app_address = os.environ.get("DASK_APPMASTER_ADDRESS", None)
        security_dir = os.environ.get("DASK_SECURITY_CREDENTIALS", None)
        if app_id is not None and app_address is not None:
            # Connection details were supplied through DASK_* variables.
            security = (
                None
                if security_dir is None
                else skein.Security.from_directory(security_dir)
            )
            app = skein.ApplicationClient(app_address, app_id, security=security)
        else:
            app = skein.ApplicationClient.from_current()

        self._init_common(application_client=app, asynchronous=asynchronous, loop=loop)
        return self

    @classmethod
    def from_application_id(
        cls, app_id, skein_client=None, asynchronous=False, loop=None
    ):
        """Connect to an existing ``YarnCluster`` with a given application id.

        Parameters
        ----------
        app_id : str
            The existing cluster's application id.
        skein_client : skein.Client
            The ``skein.Client`` to use. If not provided, one will be started.
        asynchronous : bool, optional
            If true, starts the cluster in asynchronous mode, where it can be
            used in other async code.
        loop : IOLoop, optional
            The IOLoop instance to use. Defaults to the current loop in
            asynchronous mode, otherwise a background loop is started.

        Returns
        -------
        YarnCluster
        """
        self = super(YarnCluster, cls).__new__(cls)
        skein_client = _get_skein_client(skein_client)
        app = skein_client.connect(app_id)
        self._init_common(
            application_client=app,
            asynchronous=asynchronous,
            loop=loop,
            skein_client=skein_client,
        )
        return self

    def _init_common(
        self,
        spec=None,
        application_client=None,
        host=None,
        port=None,
        dashboard_address=None,
        asynchronous=False,
        loop=None,
        skein_client=None,
    ):
        """Initialize the state shared by ``__init__`` and the ``from_*`` constructors.

        Starts the loop runner and, when not asynchronous, blocks until
        ``_start_internal`` has finished.
        """
        self.spec = spec
        self.application_client = application_client
        self._scheduler_kwargs = _make_scheduler_kwargs(
            host=host,
            port=port,
            dashboard_address=dashboard_address,
        )
        self._scheduler = None
        self.scheduler_info = {}
        self._requested = set()
        self.scheduler_comm = None
        self._watch_worker_status_task = None
        self._start_task = None
        self._stop_task = None
        self._finalizer = None
        self._adaptive = None
        self._adaptive_options = {}
        self._skein_client = skein_client
        self._asynchronous = asynchronous
        self._loop_runner = LoopRunner(loop=loop, asynchronous=asynchronous)
        self._loop_runner.start()

        if not self.asynchronous:
            self._sync(self._start_internal())

    def _start_cluster(self):
        """Start the cluster and initialize state"""
        skein_client = _get_skein_client(self._skein_client)

        if "dask.scheduler" not in self.spec.services:
            # deploy_mode == 'local'
            scheduler_address = self._scheduler.address
            for k in ["dashboard", "bokeh"]:
                if k in self._scheduler.services:
                    dashboard_port = self._scheduler.services[k].port
                    dashboard_host = urlparse(scheduler_address).hostname
                    dashboard_address = "http://%s:%d" % (
                        dashboard_host,
                        dashboard_port,
                    )
                    break
            else:
                # Neither service is registered on the local scheduler.
                dashboard_address = None

            with submit_and_handle_failures(skein_client, self.spec) as app:
                app.kv["dask.scheduler"] = scheduler_address.encode()
                if dashboard_address is not None:
                    app.kv["dask.dashboard"] = dashboard_address.encode()
        else:
            # deploy_mode == 'remote'
            with submit_and_handle_failures(skein_client, self.spec) as app:
                scheduler_address = app.kv.wait("dask.scheduler").decode()
                dashboard_address = app.kv.get("dask.dashboard")
                if dashboard_address is not None:
                    dashboard_address = dashboard_address.decode()

        # Ensure application gets cleaned up
        self._finalizer = weakref.finalize(self, app.shutdown)

        self.scheduler_address = scheduler_address
        self._dashboard_address = dashboard_address
        self.application_client = app

    def _connect_existing(self):
        """Read the spec and addresses of an already running application.

        Raises
        ------
        ValueError
            If the application has no ``'dask.worker'`` service.
        """
        spec = self.application_client.get_specification()
        if "dask.worker" not in spec.services:
            raise ValueError("%r is not a valid dask cluster" % self.app_id)

        scheduler_address = self.application_client.kv.wait("dask.scheduler").decode()
        dashboard_address = self.application_client.kv.get("dask.dashboard")
        if dashboard_address is not None:
            dashboard_address = dashboard_address.decode()

        self.spec = spec
        self.scheduler_address = scheduler_address
        self._dashboard_address = dashboard_address

    async def _start_internal(self):
        """Await the (single) start task, stopping the cluster if it fails."""
        if self._start_task is None:
            self._start_task = asyncio.ensure_future(self._start_async())
        try:
            await self._start_task
        except BaseException:
            # On exception, cleanup
            await self._stop_internal()
            raise
        return self

    async def _start_async(self):
        """Start or connect to the cluster and subscribe to worker status.

        Any exception raised while subscribing is propagated after the
        half-opened comm has been closed.
        """
        if self.spec is not None:
            # Start a new cluster
            if "dask.scheduler" not in self.spec.services:
                self._scheduler = Scheduler(
                    loop=self.loop,
                    **self._scheduler_kwargs,
                )
                await self._scheduler
            else:
                self._scheduler = None
            await self.loop.run_in_executor(None, self._start_cluster)
        else:
            # Connect to an existing cluster
            await self.loop.run_in_executor(None, self._connect_existing)

        self.scheduler_comm = rpc(self.scheduler_address)
        comm = None
        try:
            comm = await self.scheduler_comm.live_comm()
            await comm.write({"op": "subscribe_worker_status"})
            self.scheduler_info = await comm.read()
            workers = self.scheduler_info.get("workers", {})
            self._requested.update(w["name"] for w in workers.values())
            self._watch_worker_status_task = asyncio.ensure_future(
                self._watch_worker_status(comm)
            )
        except Exception:
            if comm is not None:
                await comm.close()
            # Re-raise: swallowing the error here would report a successful
            # start and keep _start_internal from running its cleanup.
            raise

    async def _stop_internal(self, status="SUCCEEDED", diagnostics=None):
        """Await the (single) stop task, creating it on first use."""
        if self._stop_task is None:
            self._stop_task = asyncio.ensure_future(
                self._stop_async(status=status, diagnostics=diagnostics)
            )
        await self._stop_task

    async def _stop_async(self, status="SUCCEEDED", diagnostics=None):
        """Tear down tasks, comms, the application and the local scheduler."""
        if self._start_task is not None:
            if not self._start_task.done():
                # We're still starting, cancel task
                await cancel_task(self._start_task)
            self._start_task = None

        if self._adaptive is not None:
            self._adaptive.stop()

        if self._watch_worker_status_task is not None:
            await cancel_task(self._watch_worker_status_task)
            self._watch_worker_status_task = None

        if self.scheduler_comm is not None:
            self.scheduler_comm.close_rpc()
            self.scheduler_comm = None

        await self.loop.run_in_executor(
            None, lambda: self._stop_sync(status=status, diagnostics=diagnostics)
        )

        if self._scheduler is not None:
            await self._scheduler.close()
            self._scheduler = None

    def _stop_sync(self, status="SUCCEEDED", diagnostics=None):
        """Shut the application down if its finalizer is still alive."""
        if self._finalizer is not None and self._finalizer.peek() is not None:
            self.application_client.shutdown(status=status, diagnostics=diagnostics)
            self._finalizer.detach()  # don't run the finalizer later
        self._finalizer = None

    def __await__(self):
        return self.__aenter__().__await__()

    async def __aenter__(self):
        return await self._start_internal()

    async def __aexit__(self, typ, value, traceback):
        await self._stop_internal()

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

    def __repr__(self):
        return "YarnCluster<%s>" % self.app_id

    @property
    def loop(self):
        """The loop owned by the loop runner."""
        return self._loop_runner.loop

    @property
    def app_id(self):
        """The id of the application client."""
        return self.application_client.id

    @property
    def asynchronous(self):
        """Whether the cluster was created in asynchronous mode."""
        return self._asynchronous

    def _sync(self, task):
        """Return ``task`` unchanged in asynchronous mode, else run it to completion.

        In synchronous mode the coroutine is scheduled on the cluster's
        asyncio loop and this call blocks for its result; on any exception
        the scheduled future is cancelled before re-raising.
        """
        if self.asynchronous:
            return task
        future = asyncio.run_coroutine_threadsafe(task, self.loop.asyncio_loop)
        try:
            return future.result()
        except BaseException:
            future.cancel()
            raise

    @cached_property
    def dashboard_link(self):
        """Link to the dask dashboard. None if dashboard isn't running"""
        if self._dashboard_address is None:
            return None
        dashboard = urlparse(self._dashboard_address)
        return format_dashboard_link(dashboard.hostname, dashboard.port)

    def shutdown(self, status="SUCCEEDED", diagnostics=None):
        """Shutdown the application.

        Parameters
        ----------
        status : {'SUCCEEDED', 'FAILED', 'KILLED'}, optional
            The yarn application exit status.
        diagnostics : str, optional
            The application exit message, usually used for diagnosing failures.
            Can be seen in the YARN Web UI for completed applications under
            "diagnostics". If not provided, a default will be used.
        """
        if self.asynchronous:
            return self._stop_internal(status=status, diagnostics=diagnostics)
        if self.loop.asyncio_loop.is_running() and not sys.is_finalizing():
            self._sync(self._stop_internal(status=status, diagnostics=diagnostics))
        else:
            # Always run this!
            self._stop_sync(status=status, diagnostics=diagnostics)
        self._loop_runner.stop()

    def close(self, **kwargs):
        """Close this cluster. An alias for ``shutdown``.

        See Also
        --------
        shutdown
        """
        return self.shutdown(**kwargs)

    def __del__(self):
        if not hasattr(self, "_loop_runner"):
            # Construction failed before the loop runner existed.
            return
        if self.asynchronous:
            # No del for async mode
            return
        self.close()

    @property
    def _observed(self):
        """Names of the workers currently listed in ``scheduler_info``."""
        return {w["name"] for w in self.scheduler_info["workers"].values()}

    async def _workers(self):
        return await self.loop.run_in_executor(
            None,
            lambda: self.application_client.get_containers(services=["dask.worker"]),
        )

    def workers(self):
        """A list of all currently running worker containers."""
        return self._sync(self._workers())

    async def _logs(self, scheduler=True, workers=True):
        logs = Logs()

        if scheduler:
            slogs = await self.scheduler_comm.logs()
            logs["Scheduler"] = Log("\n".join(line for level, line in slogs))

        if workers:
            d = await self.scheduler_comm.worker_logs(workers=workers)
            for k, v in d.items():
                logs[k] = Log("\n".join(line for level, line in v))

        return logs

    def logs(self, scheduler=True, workers=True):
        """Return logs for the scheduler and/or workers

        Parameters
        ----------
        scheduler : boolean, optional
            Whether or not to collect logs for the scheduler
        workers : boolean or iterable, optional
            A list of worker addresses to select. Defaults to all workers if
            ``True`` or no workers if ``False``

        Returns
        -------
        logs : dict
            A dictionary of name -> logs.
        """
        return self._sync(self._logs(scheduler=scheduler, workers=workers))

    async def _scale_up(self, n):
        """Ask the application for ``n`` workers if that exceeds what was requested."""
        if n > len(self._requested):
            containers = await self.loop.run_in_executor(
                None,
                lambda: self.application_client.scale("dask.worker", n),
            )
            self._requested.update(c.id for c in containers)

    async def _scale_down(self, workers):
        """Retire the given workers on the scheduler, then kill their containers."""
        self._requested.difference_update(workers)
        await self.scheduler_comm.retire_workers(names=list(workers))

        def _kill_containers():
            for c in workers:
                try:
                    self.application_client.kill_container(c)
                except ValueError:
                    # Deliberately best-effort: a ValueError from
                    # kill_container is ignored.
                    pass

        await self.loop.run_in_executor(
            None,
            _kill_containers,
        )

    async def _scale(self, n):
        """Stop adaptive scaling, then scale up or down to ``n`` workers."""
        if self._adaptive is not None:
            self._adaptive.stop()
        if n >= len(self._requested):
            return await self._scale_up(n)
        else:
            n_to_delete = len(self._requested) - n
            # Requested-but-unobserved workers are closed before running ones.
            pending = list(self._requested - self._observed)
            running = list(self._observed.difference(pending))
            to_close = pending[:n_to_delete]
            to_close.extend(running[: n_to_delete - len(to_close)])
            await self._scale_down(to_close)
            self._requested.difference_update(to_close)

    def scale(self, n):
        """Scale cluster to n workers.

        Parameters
        ----------
        n : int
            Target number of workers

        Examples
        --------
        >>> cluster.scale(10)  # scale cluster to ten workers
        """
        return self._sync(self._scale(n))

    def adapt(
        self,
        minimum=0,
        maximum=math.inf,
        interval="1s",
        wait_count=3,
        target_duration="5s",
        **kwargs,
    ):
        """Turn on adaptivity

        This scales Dask clusters automatically based on scheduler activity.

        Parameters
        ----------
        minimum : int, optional
            Minimum number of workers. Defaults to ``0``.
        maximum : int, optional
            Maximum number of workers. Defaults to ``inf``.
        interval : timedelta or str, optional
            Time between worker add/remove recommendations.
        wait_count : int, optional
            Number of consecutive times that a worker should be suggested for
            removal before we remove it.
        target_duration : timedelta or str, optional
            Amount of time we want a computation to take. This affects how
            aggressively we scale up.
        **kwargs :
            Additional parameters to pass to
            ``distributed.Scheduler.workers_to_close``.

        Examples
        --------
        >>> cluster.adapt(minimum=0, maximum=10)
        """
        if self._adaptive is not None:
            self._adaptive.stop()
        self._adaptive_options.update(
            minimum=minimum,
            maximum=maximum,
            interval=interval,
            wait_count=wait_count,
            target_duration=target_duration,
            **kwargs,
        )
        self._adaptive = Adaptive(self, **self._adaptive_options)

    async def _watch_worker_status(self, comm):
        """Apply worker add/remove messages read from ``comm`` until it ends.

        The loop exits on ``OSError`` from ``comm.read()`` or once the cluster
        object has been garbage collected; ``comm`` is closed on the way out.
        """
        # We don't want to hold on to a ref to self, otherwise this will
        # leave a dangling reference and prevent garbage collection.
        ref_self = weakref.ref(self)
        self = None
        try:
            while True:
                try:
                    msgs = await comm.read()
                except OSError:
                    break
                try:
                    self = ref_self()
                    if self is None:
                        break
                    for op, msg in msgs:
                        if op == "add":
                            workers = msg.pop("workers")
                            self.scheduler_info["workers"].update(workers)
                            self.scheduler_info.update(msg)
                            self._requested.update(w["name"] for w in workers.values())
                        elif op == "remove":
                            del self.scheduler_info["workers"][msg]
                            self._requested.discard(msg)
                    if hasattr(self, "_status_widget"):
                        self._status_widget.value = self._widget_status()
                finally:
                    # Drop the strong reference again before the next read.
                    self = None
        finally:
            await comm.close()

    def _widget_status(self):
        """Return the status text for the widget, or None without worker info."""
        try:
            workers = self.scheduler_info["workers"]
        except KeyError:
            return None
        else:
            n_workers = len(workers)
            cores = sum(w["nthreads"] for w in workers.values())
            memory = sum(w["memory_limit"] for w in workers.values())

            return _widget_status_template % (n_workers, cores, format_bytes(memory))

    def _widget(self):
        """ Create IPython widget for display within a notebook """
        try:
            return self._cached_widget
        except AttributeError:
            pass

        if self.asynchronous:
            return None

        try:
            from ipywidgets import Layout, VBox, HBox, IntText, Button, HTML, Accordion
        except ImportError:
            self._cached_widget = None
            return None

        layout = Layout(width="150px")

        title = HTML("<h2>YarnCluster</h2>")

        status = HTML(self._widget_status(), layout=Layout(min_width="150px"))

        request = IntText(0, description="Workers", layout=layout)
        scale = Button(description="Scale", layout=layout)

        minimum = IntText(0, description="Minimum", layout=layout)
        maximum = IntText(0, description="Maximum", layout=layout)
        adapt = Button(description="Adapt", layout=layout)

        accordion = Accordion(
            [HBox([request, scale]), HBox([minimum, maximum, adapt])],
            layout=Layout(min_width="500px"),
        )
        accordion.selected_index = None
        accordion.set_title(0, "Manual Scaling")
        accordion.set_title(1, "Adaptive Scaling")

        @adapt.on_click
        def adapt_cb(b):
            self.adapt(minimum=minimum.value, maximum=maximum.value)

        @scale.on_click
        def scale_cb(b):
            with log_errors():
                self.scale(request.value)

        app_id = HTML("<p><b>Application ID: </b>{0}</p>".format(self.app_id))

        elements = [title, HBox([status, accordion]), app_id]

        if self.dashboard_link is not None:
            link = HTML(
                '<p><b>Dashboard: </b><a href="{0}" target="_blank">{0}'
                "</a></p>\n".format(self.dashboard_link)
            )
            elements.append(link)

        self._cached_widget = box = VBox(elements)
        self._status_widget = status

        return box

    def _ipython_display_(self, **kwargs):
        widget = self._widget()
        if widget is not None:
            return widget._ipython_display_(**kwargs)
        else:
            from IPython.display import display

            data = {"text/plain": repr(self), "text/html": self._repr_html_()}
            display(data, raw=True)

    def _repr_html_(self):
        if self.dashboard_link is not None:
            dashboard = "<a href='{0}' target='_blank'>{0}</a>".format(
                self.dashboard_link
            )
        else:
            dashboard = "Not Available"
        return (
            "<div style='background-color: #f2f2f2; display: inline-block; "
            "padding: 10px; border: 1px solid #999999;'>\n"
            " <h3>YarnCluster</h3>\n"
            " <ul>\n"
            " <li><b>Application ID: </b>{app_id}\n"
            " <li><b>Dashboard: </b>{dashboard}\n"
            " </ul>\n"
            "</div>\n"
        ).format(app_id=self.app_id, dashboard=dashboard)
class AzureMLCluster(Cluster): """ Deploy a Dask cluster using Azure ML This creates a dask scheduler and workers on an Azure ML Compute Target. Parameters ---------- workspace: azureml.core.Workspace (required) Azure ML Workspace - see https://aka.ms/azureml/workspace. vm_size: str (optional) Azure VM size to be used in the Compute Target - see https://aka.ms/azureml/vmsizes. datastores: List[Datastore] (optional) List of Azure ML Datastores to be mounted on the headnode - see https://aka.ms/azureml/data and https://aka.ms/azureml/datastores. Defaults to ``[]``. To mount all datastores in the workspace, set to ``ws.datastores.values()``. environment_definition: azureml.core.Environment (optional) Azure ML Environment - see https://aka.ms/azureml/environments. Defaults to the "AzureML-Dask-CPU" or "AzureML-Dask-GPU" curated environment. scheduler_idle_timeout: int (optional) Number of idle seconds leading to scheduler shut down. Defaults to ``1200`` (20 minutes). experiment_name: str (optional) The name of the Azure ML Experiment used to control the cluster. Defaults to ``dask-cloudprovider``. initial_node_count: int (optional) The initial number of nodes for the Dask Cluster. Defaults to ``1``. jupyter: bool (optional) Flag to start JupyterLab session on the headnode of the cluster. Defaults to ``False``. jupyter_port: int (optional) Port on headnode to use for hosting JupyterLab session. Defaults to ``9000``. dashboard_port: int (optional) Port on headnode to use for hosting Dask dashboard. Defaults to ``9001``. scheduler_port: int (optional) Port to map the scheduler port to via SSH-tunnel if machine not on the same VNET. Defaults to ``9002``. worker_death_timeout: int (optional) Number of seconds to wait for a worker to respond before removing it. Defaults to ``30``. additional_ports: list[tuple[int, int]] (optional) Additional ports to forward. 
This requires a list of tuples where the first element is the port to open on the headnode while the second element is the port to map to or forward via the SSH-tunnel. Defaults to ``[]``. compute_target: azureml.core.ComputeTarget (optional) Azure ML Compute Target - see https://aka.ms/azureml/computetarget. admin_username: str (optional) Username of the admin account for the AzureML Compute. Required for runs that are not on the same VNET. Defaults to empty string. Throws Exception if machine not on the same VNET. Defaults to ``""``. admin_ssh_key: str (optional) Location of the SSH secret key used when creating the AzureML Compute. The key should be passwordless if run from a Jupyter notebook. The ``id_rsa`` file needs to have 0700 permissions set. Required for runs that are not on the same VNET. Defaults to empty string. Throws Exception if machine not on the same VNET. Defaults to ``""``. vnet: str (optional) Name of the virtual network. subnet: str (optional) Name of the subnet inside the virtual network ``vnet``. vnet_resource_group: str (optional) Name of the resource group where the virtual network ``vnet`` is located. If not passed, but names for ``vnet`` and ``subnet`` are passed, ``vnet_resource_group`` is assigned with the name of resource group associated with ``workspace`` telemetry_opt_out: bool (optional) A boolean parameter. Defaults to logging a version of AzureMLCluster with Microsoft. Set this flag to False if you do not want to share this information with Microsoft. Microsoft is not tracking anything else you do in your Dask cluster nor any other information related to your workload. asynchronous: bool (optional) Flag to run jobs asynchronously. **kwargs: dict Additional keyword arguments. 
""" def __init__( self, workspace, compute_target=None, environment_definition=None, experiment_name=None, initial_node_count=None, jupyter=None, jupyter_port=None, dashboard_port=None, scheduler_port=None, scheduler_idle_timeout=None, worker_death_timeout=None, additional_ports=None, admin_username=None, admin_ssh_key=None, datastores=None, code_store=None, vnet_resource_group=None, vnet=None, subnet=None, show_output=False, telemetry_opt_out=None, asynchronous=False, **kwargs, ): ### REQUIRED PARAMETERS self.workspace = workspace self.compute_target = compute_target ### ENVIRONMENT self.environment_definition = environment_definition ### EXPERIMENT DEFINITION self.experiment_name = experiment_name self.tags = {"tag": "azureml-dask"} ### ENVIRONMENT AND VARIABLES self.initial_node_count = initial_node_count ### SEND TELEMETRY self.telemetry_opt_out = telemetry_opt_out self.telemetry_set = False ### FUTURE EXTENSIONS self.kwargs = kwargs self.show_output = show_output ## CREATE COMPUTE TARGET self.admin_username = admin_username self.admin_ssh_key = admin_ssh_key self.vnet_resource_group = vnet_resource_group self.vnet = vnet self.subnet = subnet self.compute_target_set = True self.pub_key_file = "" self.pri_key_file = "" if self.compute_target is None: try: self.compute_target = self.__create_compute_target() self.compute_target_set = False except Exception as e: logger.exception(e) return elif self.compute_target.admin_user_ssh_key is not None and ( self.admin_ssh_key is None or self.admin_username is None ): logger.exception( "Please provide private key and admin username to access compute target {}".format( self.compute_target.name ) ) return ### GPU RUN INFO self.workspace_vm_sizes = AmlCompute.supported_vmsizes(self.workspace) self.workspace_vm_sizes = [ (e["name"].lower(), e["gpus"]) for e in self.workspace_vm_sizes ] self.workspace_vm_sizes = dict(self.workspace_vm_sizes) self.compute_target_vm_size = self.compute_target.serialize()["properties"][ "status" 
]["vmSize"].lower() self.n_gpus_per_node = self.workspace_vm_sizes[self.compute_target_vm_size] self.use_gpu = True if self.n_gpus_per_node > 0 else False if self.environment_definition is None: if self.use_gpu: self.environment_definition = self.workspace.environments[ "AzureML-Dask-GPU" ] else: self.environment_definition = self.workspace.environments[ "AzureML-Dask-CPU" ] ### JUPYTER AND PORT FORWARDING self.jupyter = jupyter self.jupyter_port = jupyter_port self.dashboard_port = dashboard_port self.scheduler_port = scheduler_port self.scheduler_idle_timeout = scheduler_idle_timeout self.portforward_proc = None self.worker_death_timeout = worker_death_timeout self.end_logging = False # FLAG FOR STOPPING THE port_forward_logger THREAD if additional_ports is not None: if type(additional_ports) != list: error_message = ( f"The additional_ports parameter is of {type(additional_ports)}" " type but needs to be a list of int tuples." " Check the documentation." ) logger.exception(error_message) raise TypeError(error_message) if len(additional_ports) > 0: if type(additional_ports[0]) != tuple: error_message = ( f"The additional_ports elements are of {type(additional_ports[0])}" " type but needs to be a list of int tuples." " Check the documentation." ) raise TypeError(error_message) ### check if all elements are tuples of length two and int type all_correct = True for el in additional_ports: if type(el) != tuple or len(el) != 2: all_correct = False break if (type(el[0]), type(el[1])) != (int, int): all_correct = False break if not all_correct: error_message = ( "At least one of the elements of the additional_ports parameter" " is wrong. Make sure it is a list of int tuples." " Check the documentation." 
) raise TypeError(error_message) self.additional_ports = additional_ports self.scheduler_ip_port = ( None ### INIT FOR HOLDING THE ADDRESS FOR THE SCHEDULER ) ### DATASTORES self.datastores = datastores ### RUNNING IN MATRIX OR LOCAL self.same_vnet = None self.is_in_ci = False ### GET RUNNING LOOP self._loop_runner = LoopRunner(loop=None, asynchronous=asynchronous) self.loop = self._loop_runner.loop self.abs_path = pathlib.Path(__file__).parent.absolute() ### INITIALIZE CLUSTER super().__init__(asynchronous=asynchronous) if not self.asynchronous: self._loop_runner.start() self.sync(self.__get_defaults) if not self.telemetry_opt_out: self.__append_telemetry() self.sync(self.__create_cluster) async def __get_defaults(self): self.config = dask.config.get("cloudprovider.azure", {}) if self.experiment_name is None: self.experiment_name = self.config.get("experiment_name") if self.initial_node_count is None: self.initial_node_count = self.config.get("initial_node_count") if self.jupyter is None: self.jupyter = self.config.get("jupyter") if self.jupyter_port is None: self.jupyter_port = self.config.get("jupyter_port") if self.dashboard_port is None: self.dashboard_port = self.config.get("dashboard_port") if self.scheduler_port is None: self.scheduler_port = self.config.get("scheduler_port") if self.scheduler_idle_timeout is None: self.scheduler_idle_timeout = self.config.get("scheduler_idle_timeout") if self.worker_death_timeout is None: self.worker_death_timeout = self.config.get("worker_death_timeout") if self.additional_ports is None: self.additional_ports = self.config.get("additional_ports") if self.admin_username is None: self.admin_username = self.config.get("admin_username") if self.admin_ssh_key is None: self.admin_ssh_key = self.config.get("admin_ssh_key") if self.datastores is None: self.datastores = self.config.get("datastores") if self.telemetry_opt_out is None: self.telemetry_opt_out = self.config.get("telemetry_opt_out") ### PARAMETERS TO START THE CLUSTER 
self.scheduler_params = {} self.worker_params = {} ### scheduler and worker parameters self.scheduler_params["--jupyter"] = self.jupyter self.scheduler_params["--scheduler_idle_timeout"] = self.scheduler_idle_timeout self.worker_params["--worker_death_timeout"] = self.worker_death_timeout if self.use_gpu: self.scheduler_params["--use_gpu"] = True self.scheduler_params["--n_gpus_per_node"] = self.n_gpus_per_node self.worker_params["--use_gpu"] = True self.worker_params["--n_gpus_per_node"] = self.n_gpus_per_node ### CLUSTER PARAMS self.max_nodes = self.compute_target.serialize()["properties"]["properties"][ "scaleSettings" ]["maxNodeCount"] self.scheduler_ip_port = None self.workers_list = [] self.URLs = {} ### SANITY CHECKS ###-----> initial node count if self.initial_node_count > self.max_nodes: self.initial_node_count = self.max_nodes def __append_telemetry(self): if not self.telemetry_set: self.telemetry_set = True try: from azureml._base_sdk_common.user_agent import append append("AzureMLCluster-DASK", "0.1") except ImportError: pass def __print_message(self, msg, length=80, filler="#", pre_post=""): logger.info(msg) if self.show_output: print(f"{pre_post} {msg} {pre_post}".center(length, filler)) async def __check_if_scheduler_ip_reachable(self): """ Private method to determine if running in the cloud within the same VNET and the scheduler node is reachable """ try: ip, port = self.scheduler_ip_port.split(":") socket.create_connection((ip, port), 20) self.same_vnet = True self.__print_message("On the same VNET") except socket.timeout as e: self.__print_message("Not on the same VNET") self.same_vnet = False except ConnectionRefusedError as e: logger.info(e) self.__print_message(e) pass def __prepare_rpc_connection_to_headnode(self): if self.same_vnet: return self.run.get_metrics()["scheduler"] elif self.is_in_ci: uri = f"{self.hostname}:{self.scheduler_port}" return uri else: uri = f"localhost:{self.scheduler_port}" self.hostname = "localhost" 
logger.info(f"Local connection: {uri}") return uri def __get_ssh_keys(self): from cryptography.hazmat.primitives import serialization as crypto_serialization from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.hazmat.backends import ( default_backend as crypto_default_backend, ) dir_path = os.path.join(os.getcwd(), "tmp") if not os.path.exists(dir_path): os.makedirs(dir_path) pub_key_file = os.path.join(dir_path, "key.pub") pri_key_file = os.path.join(dir_path, "key") key = rsa.generate_private_key( backend=crypto_default_backend(), public_exponent=65537, key_size=2048 ) private_key = key.private_bytes( crypto_serialization.Encoding.PEM, crypto_serialization.PrivateFormat.PKCS8, crypto_serialization.NoEncryption(), ) public_key = key.public_key().public_bytes( crypto_serialization.Encoding.OpenSSH, crypto_serialization.PublicFormat.OpenSSH, ) with open(pub_key_file, "wb") as f: f.write(public_key) with open(pri_key_file, "wb") as f: f.write(private_key) os.chmod(pri_key_file, 0o600) with open(pub_key_file, "r") as f: pubkey = f.read() self.pub_key_file = pub_key_file self.pri_key_file = pri_key_file return pubkey, pri_key_file def __create_compute_target(self): import random tmp_name = "dask-ct-{}".format(random.randint(100000, 999999)) ct_name = self.kwargs.get("ct_name", tmp_name) vm_name = self.kwargs.get("vm_size", "STANDARD_DS3_V2") min_nodes = int(self.kwargs.get("min_nodes", "0")) max_nodes = int(self.kwargs.get("max_nodes", "100")) idle_time = int(self.kwargs.get("idle_time", "300")) vnet_rg = None vnet_name = None subnet_name = None if self.admin_username is None: self.admin_username = "******" ssh_key_pub, self.admin_ssh_key = self.__get_ssh_keys() if self.vnet and self.subnet: vnet_name = self.vnet subnet_name = self.subnet if self.vnet_resource_group: vnet_rg = self.vnet_resource_group else: vnet_rg = self.workspace.resource_group try: if ct_name not in self.workspace.compute_targets: config = 
AmlCompute.provisioning_configuration( vm_size=vm_name, min_nodes=min_nodes, max_nodes=max_nodes, vnet_resourcegroup_name=vnet_rg, vnet_name=vnet_name, subnet_name=subnet_name, idle_seconds_before_scaledown=idle_time, admin_username=self.admin_username, admin_user_ssh_key=ssh_key_pub, remote_login_port_public_access="Enabled", ) self.__print_message("Creating new compute targe: {}".format(ct_name)) ct = ComputeTarget.create(self.workspace, ct_name, config) ct.wait_for_completion(show_output=self.show_output) else: self.__print_message( "Using existing compute target: {}".format(ct_name) ) ct = self.workspace.compute_targets[ct_name] except Exception as e: logger.exception("Cannot create/get compute target. {}".format(e)) raise e return ct def __delete_compute_target(self): try: self.compute_target.delete() except ComputeTargetException as e: logger.exception( "Compute target {} cannot be removed. You may need to delete it manually. {}".format( self.compute_target.name, e ) ) async def __create_cluster(self): self.__print_message("Setting up cluster") exp = Experiment(self.workspace, self.experiment_name) estimator = Estimator( os.path.join(self.abs_path, "setup"), compute_target=self.compute_target, entry_script="start_scheduler.py", environment_definition=self.environment_definition, script_params=self.scheduler_params, node_count=1, ### start only scheduler distributed_training=MpiConfiguration(), use_docker=True, inputs=self.datastores, ) run = exp.submit(estimator, tags=self.tags) self.__print_message("Waiting for scheduler node's IP") status = run.get_status() while ( status != "Canceled" and status != "Failed" and "scheduler" not in run.get_metrics() ): print(".", end="") logger.info("Scheduler not ready") time.sleep(5) status = run.get_status() if status == "Canceled" or status == "Failed": run_error = run.get_details().get("error") error_message = "Failed to start the AzureML cluster." 
if run_error: error_message = "{} {}".format(error_message, run_error) logger.exception(error_message) if not self.compute_target_set: self.__delete_compute_target() raise Exception(error_message) print("\n") ### SET FLAGS self.scheduler_ip_port = run.get_metrics()["scheduler"] self.worker_params["--scheduler_ip_port"] = self.scheduler_ip_port self.__print_message(f'Scheduler: {run.get_metrics()["scheduler"]}') self.run = run ### CHECK IF ON THE SAME VNET max_retry = 5 while self.same_vnet is None and max_retry > 0: time.sleep(5) await self.sync(self.__check_if_scheduler_ip_reachable) max_retry -= 1 if self.same_vnet is None: self.run.cancel() if not self.compute_target_set: self.__delete_compute_target() logger.exception( "Connection error after retrying. Failed to start the AzureML cluster." ) return ### REQUIRED BY dask.distributed.deploy.cluster.Cluster self.hostname = socket.gethostname() self.is_in_ci = ( f"/mnt/batch/tasks/shared/LS_root/mounts/clusters/{self.hostname}" in os.getcwd() ) _scheduler = self.__prepare_rpc_connection_to_headnode() self.scheduler_comm = rpc(_scheduler) await self.sync(self.__setup_port_forwarding) try: await super()._start() except Exception as e: logger.exception(e) # CLEAN UP COMPUTE TARGET self.run.cancel() if not self.compute_target_set: self.__delete_compute_target() return await self.sync(self.__update_links) self.__print_message("Connections established") self.__print_message(f"Scaling to {self.initial_node_count} workers") if self.initial_node_count > 1: self.scale( self.initial_node_count ) # LOGIC TO KEEP PROPER TRACK OF WORKERS IN `scale` self.__print_message("Scaling is done") async def __update_links(self): token = self.run.get_metrics()["token"] if self.same_vnet or self.is_in_ci: location = self.workspace.get_details()["location"] self.scheduler_info[ "dashboard_url" ] = f"https://{self.hostname}-{self.dashboard_port}.{location}.instances.azureml.net/status" self.scheduler_info[ "jupyter_url" ] = 
f"https://{self.hostname}-{self.jupyter_port}.{location}.instances.azureml.net/lab?token={token}" else: hostname = "localhost" self.scheduler_info[ "dashboard_url" ] = f"http://{hostname}:{self.dashboard_port}" self.scheduler_info[ "jupyter_url" ] = f"http://{hostname}:{self.jupyter_port}/?token={token}" logger.info(f'Dashboard URL: {self.scheduler_info["dashboard_url"]}') logger.info(f'Jupyter URL: {self.scheduler_info["jupyter_url"]}') def __port_forward_logger(self, portforward_proc): portforward_log = open("portforward_out_log.txt", "w") while True: portforward_out = portforward_proc.stdout.readline() if portforward_proc != "": portforward_log.write(portforward_out) portforward_log.flush() if self.end_logging: break return async def __setup_port_forwarding(self): dashboard_address = self.run.get_metrics()["dashboard"] jupyter_address = self.run.get_metrics()["jupyter"] scheduler_ip = self.run.get_metrics()["scheduler"].split(":")[0] self.__print_message("Running in compute instance? 
{}".format(self.is_in_ci)) os.system( "killall socat" ) # kill all socat processes - cleans up previous port forward setups if self.same_vnet: os.system( f"setsid socat tcp-listen:{self.dashboard_port},reuseaddr,fork tcp:{dashboard_address} &" ) os.system( f"setsid socat tcp-listen:{self.jupyter_port},reuseaddr,fork tcp:{jupyter_address} &" ) ### map additional ports for port in self.additional_ports: os.system( f"setsid socat tcp-listen:{self.port[1]},reuseaddr,fork tcp:{scheduler_ip}:{port[0]} &" ) else: scheduler_public_ip = self.compute_target.list_nodes()[0]["publicIpAddress"] scheduler_public_port = self.compute_target.list_nodes()[0]["port"] self.__print_message("scheduler_public_ip: {}".format(scheduler_public_ip)) self.__print_message( "scheduler_public_port: {}".format(scheduler_public_port) ) host_ip = "0.0.0.0" if self.is_in_ci: host_ip = socket.gethostbyname(self.hostname) cmd = ( "ssh -vvv -o StrictHostKeyChecking=no -N" f" -i {os.path.expanduser(self.admin_ssh_key)}" f" -L {host_ip}:{self.jupyter_port}:{scheduler_ip}:8888" f" -L {host_ip}:{self.dashboard_port}:{scheduler_ip}:8787" f" -L {host_ip}:{self.scheduler_port}:{scheduler_ip}:8786" ) for port in self.additional_ports: cmd += f" -L {host_ip}:{port[1]}:{scheduler_ip}:{port[0]}" cmd += f" {self.admin_username}@{scheduler_public_ip} -p {scheduler_public_port}" self.portforward_proc = subprocess.Popen( cmd.split(), universal_newlines=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, ) ### Starting thread to keep the SSH tunnel open on Windows portforward_logg = threading.Thread( target=self.__port_forward_logger, args=[self.portforward_proc] ) portforward_logg.start() @property def dashboard_link(self): """ Link to Dask dashboard. """ try: link = self.scheduler_info["dashboard_url"] except KeyError: return "" else: return link @property def jupyter_link(self): """ Link to JupyterLab on running on the headnode of the cluster. Set ``jupyter=True`` when creating the ``AzureMLCluster``. 
""" try: link = self.scheduler_info["jupyter_url"] except KeyError: return "" else: return link def _format_nodes(self, nodes, requested, use_gpu, n_gpus_per_node=None): if use_gpu: if nodes == requested: return f"{nodes}" else: return f"{nodes} / {requested}" else: if nodes == requested: return f"{nodes}" else: return f"{nodes} / {requested}" def _widget_status(self): ### reporting proper number of nodes vs workers in a multi-GPU worker scenario nodes = len(self.scheduler_info["workers"]) if self.use_gpu: nodes = int(nodes / self.n_gpus_per_node) if hasattr(self, "worker_spec"): requested = sum( 1 if "group" not in each else len(each["group"]) for each in self.worker_spec.values() ) elif hasattr(self, "nodes"): requested = len(self.nodes) else: requested = nodes nodes = self._format_nodes(nodes, requested, self.use_gpu, self.n_gpus_per_node) cores = sum(v["nthreads"] for v in self.scheduler_info["workers"].values()) cores_or_gpus = "Workers (GPUs)" if self.use_gpu else "Workers (vCPUs)" memory = ( sum( v["gpu"]["memory-total"][0] for v in self.scheduler_info["workers"].values() ) if self.use_gpu else sum(v["memory_limit"] for v in self.scheduler_info["workers"].values()) ) memory = format_bytes(memory) text = """ <div> <style scoped> .dataframe tbody tr th:only-of-type { vertical-align: middle; } .dataframe tbody tr th { vertical-align: top; } .dataframe thead th { text-align: right; } </style> <table style="text-align: right;"> <tr> <th>Nodes</th> <td>%s</td></tr> <tr> <th>%s</th> <td>%s</td></tr> <tr> <th>Memory</th> <td>%s</td></tr> </table> </div> """ % ( nodes, cores_or_gpus, cores, memory, ) return text def _widget(self): """ Create IPython widget for display within a notebook """ try: return self._cached_widget except AttributeError: pass try: from ipywidgets import Layout, VBox, HBox, IntText, Button, HTML, Accordion except ImportError: self._cached_widget = None return None layout = Layout(width="150px") if self.dashboard_link: dashboard_link = ( 
'<p><b>Dashboard: </b><a href="%s" target="_blank">%s</a></p>\n' % (self.dashboard_link, self.dashboard_link) ) else: dashboard_link = "" if self.jupyter_link: jupyter_link = ( '<p><b>Jupyter: </b><a href="%s" target="_blank">%s</a></p>\n' % (self.jupyter_link, self.jupyter_link) ) else: jupyter_link = "" title = "<h2>%s</h2>" % self._cluster_class_name title = HTML(title) dashboard = HTML(dashboard_link) jupyter = HTML(jupyter_link) status = HTML(self._widget_status(), layout=Layout(min_width="150px")) if self._supports_scaling: request = IntText( self.initial_node_count, description="Nodes", layout=layout ) scale = Button(description="Scale", layout=layout) minimum = IntText(0, description="Minimum", layout=layout) maximum = IntText(0, description="Maximum", layout=layout) adapt = Button(description="Adapt", layout=layout) accordion = Accordion( [HBox([request, scale]), HBox([minimum, maximum, adapt])], layout=Layout(min_width="500px"), ) accordion.selected_index = None accordion.set_title(0, "Manual Scaling") accordion.set_title(1, "Adaptive Scaling") def adapt_cb(b): self.adapt(minimum=minimum.value, maximum=maximum.value) update() adapt.on_click(adapt_cb) def scale_cb(b): with log_errors(): n = request.value with suppress(AttributeError): self._adaptive.stop() self.scale(n) update() scale.on_click(scale_cb) else: accordion = HTML("") box = VBox([title, HBox([status, accordion]), jupyter, dashboard]) self._cached_widget = box def update(): self.close_when_disconnect() status.value = self._widget_status() pc = PeriodicCallback(update, 500) # , io_loop=self.loop) self.periodic_callbacks["cluster-repr"] = pc pc.start() return box def close_when_disconnect(self): status = self.run.get_status() if status == "Canceled" or status == "Completed" or status == "Failed": self.close() def scale(self, workers=1): """ Scale the cluster. Scales to a maximum of the workers available in the cluster. 
""" if workers <= 0: self.close() return count = len(self.workers_list) + 1 # one more worker in head node if count < workers: self.scale_up(workers - count) elif count > workers: self.scale_down(count - workers) else: self.__print_message(f"Number of workers: {workers}") # scale up def scale_up(self, workers=1): """ Scale up the number of workers. """ run_config = RunConfiguration() run_config.target = self.compute_target run_config.environment = self.environment_definition scheduler_ip = self.run.get_metrics()["scheduler"] args = [ f"--scheduler_ip_port={scheduler_ip}", f"--use_gpu={self.use_gpu}", f"--n_gpus_per_node={self.n_gpus_per_node}", f"--worker_death_timeout={self.worker_death_timeout}", ] child_run_config = ScriptRunConfig( source_directory=os.path.join(self.abs_path, "setup"), script="start_worker.py", arguments=args, run_config=run_config, ) for i in range(workers): child_run = self.run.submit_child(child_run_config, tags=self.tags) self.workers_list.append(child_run) # scale down def scale_down(self, workers=1): """ Scale down the number of workers. Scales to minimum of 1. 
""" for i in range(workers): if self.workers_list: child_run = self.workers_list.pop(0) # deactivate oldest workers child_run.complete() # complete() will mark the run "Complete", but won't kill the process child_run.cancel() else: self.__print_message("All scaled workers are removed.") # close cluster async def _close(self): if self.status == "closed": return while self.workers_list: child_run = self.workers_list.pop() child_run.complete() child_run.cancel() if self.run: self.run.complete() self.run.cancel() self.status = "closed" self.__print_message("Scheduler and workers are disconnected.") if self.portforward_proc is not None: ### STOP LOGGING SSH self.portforward_proc.terminate() self.end_logging = True ### REMOVE TEMP FILE if os.path.isfile(self.pub_key_file): os.remove(self.pub_key_file) if os.path.isfile(self.pri_key_file): os.remove(self.pri_key_file) if not self.compute_target_set: ### REMOVE COMPUTE TARGET self.__delete_compute_target() time.sleep(30) await super()._close() def close(self): """ Close the cluster. All Azure ML Runs corresponding to the scheduler and worker processes will be completed. The Azure ML Compute Target will return to its minimum number of nodes after its idle time before scaledown. """ return self.sync(self._close)
class Gateway(object):
    """A client for a Dask Gateway Server.

    Parameters
    ----------
    address : str, optional
        The address to the gateway server.
    proxy_address : str, int, optional
        The address of the scheduler proxy server. If an int, it's used as the
        port, with the host/ip taken from ``address``. Provide a full address
        if a different host/ip should be used.
    auth : GatewayAuth, optional
        The authentication method to use.
    asynchronous : bool, optional
        If true, starts the client in asynchronous mode, where it can be used
        in other async code.
    loop : IOLoop, optional
        The IOLoop instance to use. Defaults to the current loop in
        asynchronous mode, otherwise a background loop is started.

    Raises
    ------
    ValueError
        If no gateway address or proxy address is given or configured.
    TypeError
        If ``proxy_address`` is neither an ``int`` nor a ``str``.
    """

    def __init__(self, address=None, proxy_address=None, auth=None,
                 asynchronous=False, loop=None):
        if address is None:
            address = dask.config.get("gateway.address")
        if address is None:
            raise ValueError(
                "No dask-gateway address provided or found in configuration")
        address = address.rstrip("/")

        if proxy_address is None:
            proxy_address = dask.config.get("gateway.proxy-address")
        if proxy_address is None:
            raise ValueError(
                "No dask-gateway proxy address provided or found in configuration"
            )
        if isinstance(proxy_address, int):
            # A bare port: reuse the host of the gateway address.
            parsed = urlparse(address)
            proxy_netloc = "%s:%d" % (parsed.hostname, proxy_address)
        elif isinstance(proxy_address, str):
            # Accept both "scheme://host:port" and plain "host:port".
            parsed = urlparse(proxy_address)
            proxy_netloc = parsed.netloc if parsed.netloc else proxy_address
        else:
            # Previously this fell through to an UnboundLocalError below.
            raise TypeError(
                "proxy_address must be an int or a str, got %r"
                % type(proxy_address).__name__)
        proxy_address = "gateway://%s" % proxy_netloc

        self.address = address
        self.proxy_address = proxy_address
        self._auth = get_auth(auth)
        self._cookie_jar = CookieJar()

        self._asynchronous = asynchronous
        self._loop_runner = LoopRunner(loop=loop, asynchronous=asynchronous)
        self.loop = self._loop_runner.loop

        self._loop_runner.start()
        if self._asynchronous:
            # Only set in asynchronous mode; awaited by ``__await__`` and
            # ``__aenter__``.
            self._started = self._start()
        else:
            self.sync(self._start)

    async def _start(self):
        """Create the HTTP client used for all gateway requests."""
        self._http_client = AsyncHTTPClient()

    def close(self):
        """Close this gateway client"""
        if not self.asynchronous:
            self._loop_runner.stop()

    def __del__(self):
        # __del__ is still called, even if __init__ failed. Only close if init
        # actually succeeded.
        # NOTE(review): ``_started`` is only assigned in asynchronous mode, and
        # ``close()`` does nothing in that mode, so this currently never stops
        # the loop runner. Left unchanged; confirm the intended behaviour.
        if hasattr(self, "_started"):
            self.close()

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

    def __await__(self):
        return self._started.__await__()

    async def __aenter__(self):
        await self._started
        return self

    async def __aexit__(self, typ, value, traceback):
        pass

    def __repr__(self):
        return "Gateway<%s>" % self.address

    @property
    def asynchronous(self):
        """Whether calls should return awaitables instead of blocking.

        True when the client was created asynchronous, when the current thread
        is flagged asynchronous, or when called from the loop's own thread.
        """
        return (self._asynchronous or
                getattr(thread_state, "asynchronous", False) or
                (hasattr(self.loop, "_thread_identity") and
                 self.loop._thread_identity == get_ident()))

    def sync(self, func, *args, **kwargs):
        """Run coroutine function ``func`` blocking, or return its awaitable.

        In asynchronous mode (or with ``asynchronous=True``) the awaitable is
        returned, optionally wrapped in a ``callback_timeout`` (seconds).
        Otherwise the call blocks on the client's loop until it finishes.
        """
        if kwargs.pop("asynchronous", None) or self.asynchronous:
            callback_timeout = kwargs.pop("callback_timeout", None)
            future = func(*args, **kwargs)
            if callback_timeout is not None:
                future = gen.with_timeout(timedelta(seconds=callback_timeout),
                                          future)
            return future
        else:
            return sync(self.loop, func, *args, **kwargs)

    async def _fetch(self, req):
        """Send ``req``, retrying once with credentials on a 401.

        Error responses are translated: 599 -> ``TimeoutError``, 404/422 ->
        ``ValueError``, 409 -> ``GatewayClusterError``, 500 ->
        ``GatewayServerError``; anything else is re-raised by the response.
        """
        try:
            self._cookie_jar.pre_request(req)
            resp = await self._http_client.fetch(req, raise_error=False)
            if resp.code == 401:
                context = self._auth.pre_request(req, resp)
                resp = await self._http_client.fetch(req, raise_error=False)
                self._auth.post_response(req, resp, context)
            self._cookie_jar.post_response(resp)
            if resp.error:
                if resp.code == 599:
                    raise TimeoutError("Request timed out")
                else:
                    # Prefer the server's JSON error message; fall back to the
                    # raw body when it is not JSON.
                    try:
                        msg = json.loads(resp.body)["error"]
                    except Exception:
                        msg = resp.body.decode()

                    if resp.code in {404, 422}:
                        raise ValueError(msg)
                    elif resp.code == 409:
                        raise GatewayClusterError(msg)
                    elif resp.code == 500:
                        raise GatewayServerError(msg)
                    else:
                        resp.rethrow()
        except HTTPError as exc:
            # Tornado 6 still raises these above with raise_error=False
            if exc.code == 599:
                raise TimeoutError("Request timed out")
            # Should never get here!
            raise
        return resp

    async def _clusters(self, status=None):
        if status is not None:
            if isinstance(status, (str, ClusterStatus)):
                status = [ClusterStatus._create(status)]
            else:
                status = [ClusterStatus._create(s) for s in status]
            query = "?status=" + ",".join(s.name for s in status)
        else:
            query = ""

        url = "%s/gateway/api/clusters/%s" % (self.address, query)
        req = HTTPRequest(url=url)
        resp = await self._fetch(req)
        return [
            ClusterReport._from_json(self.address, self.proxy_address, r)
            for r in json.loads(resp.body).values()
        ]

    def list_clusters(self, status=None, **kwargs):
        """List clusters for this user.

        Parameters
        ----------
        status : ClusterStatus, str, or list, optional
            The cluster status (or statuses) to select. Valid options are
            'starting', 'started', 'running', 'stopping', 'stopped', 'failed'.
            By default selects active clusters ('starting', 'started',
            'running').

        Returns
        -------
        clusters : list of ClusterReport
        """
        return self.sync(self._clusters, status=status, **kwargs)

    async def _cluster_options(self):
        url = "%s/gateway/api/clusters/options" % self.address
        req = HTTPRequest(url=url, method="GET")
        resp = await self._fetch(req)
        data = json.loads(resp.body)
        return Options._from_spec(data["cluster_options"])

    def cluster_options(self, **kwargs):
        """Get the available cluster configuration options.

        Returns
        -------
        cluster_options : Options
            A dict of cluster options.
        """
        return self.sync(self._cluster_options, **kwargs)

    async def _submit(self, cluster_options=None, **kwargs):
        url = "%s/gateway/api/clusters/" % self.address
        if cluster_options is not None:
            if not isinstance(cluster_options, Options):
                raise TypeError(
                    "cluster_options must be an `Options`, got %r"
                    % type(cluster_options).__name__)
            # Keyword arguments override values from ``cluster_options``.
            options = dict(cluster_options)
            options.update(kwargs)
        else:
            options = kwargs
        req = HTTPRequest(
            url=url,
            method="POST",
            body=json.dumps({"cluster_options": options}),
            headers=HTTPHeaders({"Content-type": "application/json"}),
        )
        resp = await self._fetch(req)
        data = json.loads(resp.body)
        return data["name"]

    def submit(self, cluster_options=None, **kwargs):
        """Submit a new cluster to be started.

        This returns quickly with a ``cluster_name``, which can later be used
        to connect to the cluster.

        Parameters
        ----------
        cluster_options : Options, optional
            An ``Options`` object describing the desired cluster configuration.
        **kwargs :
            Additional cluster configuration options. If ``cluster_options`` is
            provided, these are applied afterwards as overrides. Available
            options are specific to each deployment of dask-gateway, see
            ``cluster_options`` for more information.

        Returns
        -------
        cluster_name : str
            The cluster name.
        """
        # ``cluster_options`` must be forwarded explicitly; it used to be
        # dropped here, so the submitted cluster ignored it.
        return self.sync(self._submit, cluster_options=cluster_options, **kwargs)

    async def _cluster_report(self, cluster_name, wait=False):
        params = "?wait" if wait else ""
        url = "%s/gateway/api/clusters/%s%s" % (self.address, cluster_name, params)
        req = HTTPRequest(url=url)
        resp = await self._fetch(req)
        return ClusterReport._from_json(self.address, self.proxy_address,
                                        json.loads(resp.body))

    async def _connect(self, cluster_name):
        # Poll until the cluster is running, or has failed/stopped.
        while True:
            try:
                report = await self._cluster_report(cluster_name, wait=True)
            except TimeoutError:
                # Timeout, ignore
                pass
            else:
                if report.status is ClusterStatus.RUNNING:
                    return GatewayCluster(
                        gateway=self,
                        name=report.name,
                        scheduler_address=report.scheduler_address,
                        dashboard_link=report.dashboard_link,
                        security=report.security,
                    )
                elif report.status is ClusterStatus.FAILED:
                    raise GatewayClusterError(
                        "Cluster %r failed to start, see logs for "
                        "more information" % cluster_name)
                elif report.status is ClusterStatus.STOPPED:
                    raise GatewayClusterError("Cluster %r is already stopped"
                                              % cluster_name)
            # Not started yet, try again later
            await gen.sleep(0.5)

    def connect(self, cluster_name, **kwargs):
        """Connect to a submitted cluster.

        Returns
        -------
        cluster : GatewayCluster
        """
        return self.sync(self._connect, cluster_name, **kwargs)

    async def _new_cluster(self, **kwargs):
        cluster_name = await self._submit(**kwargs)
        try:
            return await self._connect(cluster_name)
        except GatewayClusterError:
            raise
        except BaseException:
            # Ensure cluster is stopped on error
            await self._stop_cluster(cluster_name)
            raise

    def new_cluster(self, cluster_options=None, **kwargs):
        """Submit a new cluster to the gateway, and wait for it to be started.

        Same as calling ``submit`` and ``connect`` in one go.

        Parameters
        ----------
        cluster_options : Options, optional
            An ``Options`` object describing the desired cluster configuration.
        **kwargs :
            Additional cluster configuration options. If ``cluster_options`` is
            provided, these are applied afterwards as overrides. Available
            options are specific to each deployment of dask-gateway, see
            ``cluster_options`` for more information.

        Returns
        -------
        cluster : GatewayCluster
        """
        return self.sync(self._new_cluster, cluster_options=cluster_options,
                         **kwargs)

    async def _stop_cluster(self, cluster_name):
        url = "%s/gateway/api/clusters/%s" % (self.address, cluster_name)
        req = HTTPRequest(url=url, method="DELETE")
        await self._fetch(req)

    def stop_cluster(self, cluster_name, **kwargs):
        """Stop a cluster.

        Parameters
        ----------
        cluster_name : str
            The cluster name.
        """
        return self.sync(self._stop_cluster, cluster_name, **kwargs)

    async def _scale_cluster(self, cluster_name, n):
        url = "%s/gateway/api/clusters/%s/workers" % (self.address, cluster_name)
        req = HTTPRequest(
            url=url,
            method="PUT",
            body=json.dumps({"worker_count": n}),
            headers=HTTPHeaders({"Content-type": "application/json"}),
        )
        await self._fetch(req)

    def scale_cluster(self, cluster_name, n, **kwargs):
        """Scale a cluster to n workers.

        Parameters
        ----------
        cluster_name : str
            The cluster name.
        n : int
            The number of workers to scale to.
        """
        return self.sync(self._scale_cluster, cluster_name, n, **kwargs)
class Gateway(object):
    """A client for a Dask Gateway Server.

    Parameters
    ----------
    address : str, optional
        The address to the gateway server. If not provided, the
        ``gateway.address`` dask configuration value is used.
    auth : GatewayAuth, optional
        The authentication method to use.
    asynchronous : bool, optional
        If true, starts the client in asynchronous mode, where it can be used
        in other async code.
    loop : IOLoop, optional
        The IOLoop instance to use. Defaults to the current loop in
        asynchronous mode, otherwise a background loop is started.

    Raises
    ------
    ValueError
        If no address is given and none is found in the configuration.
    """

    def __init__(self, address=None, auth=None, asynchronous=False, loop=None):
        if address is None:
            address = dask.config.get("gateway.address")
        if address is None:
            raise ValueError(
                "No dask-gateway address provided or found in configuration")
        self.address = address.rstrip("/")
        self._auth = get_auth(auth)
        self._cookie_jar = CookieJar()

        self._asynchronous = asynchronous
        self._loop_runner = LoopRunner(loop=loop, asynchronous=asynchronous)
        self.loop = self._loop_runner.loop
        self._loop_runner.start()
        if self._asynchronous:
            # Startup is deferred until the client is awaited (or entered with
            # ``async with``).
            self._started = self._start()
        else:
            self.sync(self._start)
        # Marks a fully constructed client in *both* modes. ``__del__`` keys
        # off this flag; ``_started`` only exists in asynchronous mode, so
        # using it as the guard meant synchronous clients were never closed
        # on garbage collection.
        self._initialized = True

    async def _start(self):
        """Create the HTTP client used for all requests."""
        self._http_client = AsyncHTTPClient()

    def close(self):
        """Close this gateway client"""
        # Only a synchronous client stops the loop runner it started.
        if not self.asynchronous:
            self._loop_runner.stop()

    def __del__(self):
        # __del__ is still called, even if __init__ failed. Only close if init
        # actually succeeded.
        if getattr(self, "_initialized", False):
            self.close()

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

    def __await__(self):
        return self._started.__await__()

    async def __aenter__(self):
        await self._started
        return self

    async def __aexit__(self, typ, value, traceback):
        pass

    def __repr__(self):
        return "Gateway<%s>" % self.address

    @property
    def asynchronous(self):
        """Whether calls should currently return awaitables.

        True when the client was created in asynchronous mode, when the
        thread-local ``thread_state.asynchronous`` flag is set, or when the
        caller is running on the loop's own thread.
        """
        return (self._asynchronous
                or getattr(thread_state, "asynchronous", False)
                or (hasattr(self.loop, "_thread_identity")
                    and self.loop._thread_identity == get_ident()))

    def sync(self, func, *args, **kwargs):
        """Call the coroutine function ``func`` appropriately for the mode.

        In asynchronous mode (or when ``asynchronous=True`` is passed) the
        awaitable is returned, optionally wrapped in a timeout of
        ``callback_timeout`` seconds. Otherwise the call is run on the
        client's loop and its result is returned.
        """
        if kwargs.pop("asynchronous", None) or self.asynchronous:
            callback_timeout = kwargs.pop("callback_timeout", None)
            future = func(*args, **kwargs)
            if callback_timeout is not None:
                future = gen.with_timeout(
                    timedelta(seconds=callback_timeout), future)
            return future
        else:
            return sync(self.loop, func, *args, **kwargs)

    async def _fetch(self, req, raise_error=True):
        """Send ``req``, retrying once with credentials on a 401 response.

        Cookies are applied before the request and recorded from the final
        response. If ``raise_error`` is true, error responses are rethrown.
        """
        self._cookie_jar.pre_request(req)
        resp = await self._http_client.fetch(req, raise_error=False)
        if resp.code == 401:
            # Let the auth handler update the request, then retry once.
            context = self._auth.pre_request(req, resp)
            resp = await self._http_client.fetch(req, raise_error=False)
            self._auth.post_response(req, resp, context)
        self._cookie_jar.post_response(resp)
        if raise_error:
            resp.rethrow()
        return resp

    async def _clusters(self, status=None):
        if status is not None:
            # Accept a single status or an iterable of statuses.
            if isinstance(status, (str, ClusterStatus)):
                status = [ClusterStatus._create(status)]
            else:
                status = [ClusterStatus._create(s) for s in status]
            query = "?status=" + ",".join(s.name for s in status)
        else:
            query = ""

        url = "%s/gateway/api/clusters/%s" % (self.address, query)
        req = HTTPRequest(url=url)
        resp = await self._fetch(req)
        return [
            ClusterReport._from_json(self.address, r)
            for r in json.loads(resp.body).values()
        ]

    def list_clusters(self, status=None, **kwargs):
        """List clusters for this user.

        Parameters
        ----------
        status : ClusterStatus, str, or list, optional
            The cluster status (or statuses) to select. Valid options are
            'starting', 'started', 'running', 'stopping', 'stopped',
            'failed'. By default selects active clusters ('starting',
            'started', 'running').

        Returns
        -------
        clusters : list of ClusterReport
        """
        return self.sync(self._clusters, status=status, **kwargs)

    async def _submit(self):
        url = "%s/gateway/api/clusters/" % self.address
        req = HTTPRequest(
            url=url,
            method="POST",
            body=json.dumps({}),
            headers=HTTPHeaders({"Content-type": "application/json"}),
        )
        resp = await self._fetch(req)
        data = json.loads(resp.body)
        return data["name"]

    def submit(self, **kwargs):
        """Submit a new cluster to be started.

        This returns quickly with a ``cluster_name``, which can later be used
        to connect to the cluster.

        Returns
        -------
        cluster_name : str
            The cluster name.
        """
        return self.sync(self._submit, **kwargs)

    async def _cluster_report(self, cluster_name, wait=False):
        params = "?wait" if wait else ""
        url = "%s/gateway/api/clusters/%s%s" % (self.address, cluster_name,
                                                params)
        req = HTTPRequest(url=url)
        resp = await self._fetch(req)
        return ClusterReport._from_json(self.address, json.loads(resp.body))

    async def _connect(self, cluster_name):
        """Poll the gateway until ``cluster_name`` is running, then wrap it.

        Raises ``Exception`` if the cluster is unknown (404), failed, or is
        already stopped. Request timeouts (599) are retried.
        """
        while True:
            try:
                report = await self._cluster_report(cluster_name, wait=True)
            except HTTPError as exc:
                if exc.code == 404:
                    raise Exception("Unknown cluster %r" % cluster_name)
                elif exc.code == 599:
                    # Timeout, ignore
                    pass
                else:
                    raise
            else:
                if report.status is ClusterStatus.RUNNING:
                    return GatewayCluster(self, report)
                elif report.status is ClusterStatus.FAILED:
                    raise Exception("Cluster %r failed to start, see logs for "
                                    "more information" % cluster_name)
                elif report.status is ClusterStatus.STOPPED:
                    raise Exception("Cluster %r is already stopped" %
                                    cluster_name)
            # Not started yet, try again later
            await gen.sleep(0.5)

    def connect(self, cluster_name, **kwargs):
        """Connect to a submitted cluster.

        Returns
        -------
        cluster : GatewayCluster
        """
        return self.sync(self._connect, cluster_name, **kwargs)

    async def _new_cluster(self, **kwargs):
        cluster_name = await self._submit(**kwargs)
        try:
            return await self._connect(cluster_name)
        except BaseException:
            # Ensure cluster is stopped on error
            await self._stop_cluster(cluster_name)
            raise

    def new_cluster(self, **kwargs):
        """Submit a new cluster to the gateway, and wait for it to be started.

        Same as calling ``submit`` and ``connect`` in one go.

        Returns
        -------
        cluster : GatewayCluster
        """
        return self.sync(self._new_cluster, **kwargs)

    async def _stop_cluster(self, cluster_name):
        url = "%s/gateway/api/clusters/%s" % (self.address, cluster_name)
        req = HTTPRequest(url=url, method="DELETE")
        await self._fetch(req)

    def stop_cluster(self, cluster_name, **kwargs):
        """Stop a cluster.

        Parameters
        ----------
        cluster_name : str
            The cluster name.
        """
        return self.sync(self._stop_cluster, cluster_name, **kwargs)

    async def _scale_cluster(self, cluster_name, n):
        url = "%s/gateway/api/clusters/%s/workers" % (self.address,
                                                      cluster_name)
        req = HTTPRequest(
            url=url,
            method="PUT",
            body=json.dumps({"worker_count": n}),
            headers=HTTPHeaders({"Content-type": "application/json"}),
        )
        try:
            await self._fetch(req)
        except HTTPError as exc:
            # A 409 response is reported as the cluster not running.
            if exc.code == 409:
                raise Exception("Cluster %r is not running" % cluster_name)
            raise

    def scale_cluster(self, cluster_name, n, **kwargs):
        """Scale a cluster to n workers.

        Parameters
        ----------
        cluster_name : str
            The cluster name.
        n : int
            The number of workers to scale to.
        """
        return self.sync(self._scale_cluster, cluster_name, n, **kwargs)
class Gateway(object):
    """A client for a Dask Gateway Server.

    Parameters
    ----------
    address : str, optional
        The address to the gateway server.
    proxy_address : str, int, optional
        The address of the scheduler proxy server. Defaults to `address` if not
        provided. If an int, it's used as the port, with the host/ip taken from
        ``address``. Provide a full address if a different host/ip should be
        used.
    auth : GatewayAuth, optional
        The authentication method to use.
    asynchronous : bool, optional
        If true, starts the client in asynchronous mode, where it can be used
        in other async code.
    loop : IOLoop, optional
        The IOLoop instance to use. Defaults to the current loop in
        asynchronous mode, otherwise a background loop is started.

    Raises
    ------
    ValueError
        If no address is given and none is found in the configuration.
    TypeError
        If ``proxy_address`` is neither a ``str`` nor an ``int``.
    """

    def __init__(self, address=None, proxy_address=None, auth=None,
                 asynchronous=False, loop=None):
        if address is None:
            address = format_template(dask.config.get("gateway.address"))
        if address is None:
            raise ValueError(
                "No dask-gateway address provided or found in configuration")
        address = address.rstrip("/")

        public_address = format_template(
            dask.config.get("gateway.public-address"))
        if public_address is None:
            public_address = address
        else:
            public_address = public_address.rstrip("/")

        if proxy_address is None:
            proxy_address = format_template(
                dask.config.get("gateway.proxy-address"))

        if proxy_address is None:
            # Derive the proxy location from the gateway address itself.
            parsed = urlparse(address)
            if parsed.netloc:
                if parsed.port is None:
                    proxy_port = {
                        "http": 80,
                        "https": 443
                    }.get(parsed.scheme, 8786)
                else:
                    proxy_port = parsed.port
                proxy_netloc = "%s:%d" % (parsed.hostname, proxy_port)
            else:
                # NOTE(review): proxy_address is None here, so an address
                # without a scheme/netloc yields "gateway://None". Behavior
                # kept as-is; confirm the intended fallback.
                proxy_netloc = proxy_address
        elif isinstance(proxy_address, int):
            # A bare int is a port on the gateway's own host.
            parsed = urlparse(address)
            proxy_netloc = "%s:%d" % (parsed.hostname, proxy_address)
        elif isinstance(proxy_address, str):
            parsed = urlparse(proxy_address)
            proxy_netloc = parsed.netloc if parsed.netloc else proxy_address
        else:
            # Previously fell through to an unbound ``proxy_netloc``.
            raise TypeError("proxy_address must be a str or int, got %r" %
                            type(proxy_address).__name__)
        proxy_address = "gateway://%s" % proxy_netloc

        scheme = urlparse(address).scheme
        self._request_kwargs = _get_default_request_kwargs(scheme)

        self.address = address
        self._public_address = public_address
        self.proxy_address = proxy_address

        self.auth = get_auth(auth)
        # The HTTP session is created lazily on first request.
        self._session = None

        self._asynchronous = asynchronous
        self._loop_runner = LoopRunner(loop=loop, asynchronous=asynchronous)
        self._loop_runner.start()

    @property
    def loop(self):
        return self._loop_runner.loop

    @property
    def asynchronous(self):
        return self._asynchronous

    def sync(self, func, *args, **kwargs):
        """Call the coroutine function ``func`` appropriately for the mode.

        In asynchronous mode the result of ``func(*args, **kwargs)`` is
        returned directly for the caller to await. Otherwise the coroutine is
        submitted to the client's loop and this call blocks for its result;
        if waiting is interrupted, the submitted future is cancelled.
        """
        if self.asynchronous:
            return func(*args, **kwargs)
        else:
            future = asyncio.run_coroutine_threadsafe(
                func(*args, **kwargs), self.loop.asyncio_loop)
            try:
                return future.result()
            except BaseException:
                future.cancel()
                raise

    def close(self):
        """Close the gateway client"""
        if self.asynchronous:
            return self._cleanup()
        elif self.loop.asyncio_loop.is_running():
            # The session can only be closed while the loop is still running.
            self.sync(self._cleanup)
        self._loop_runner.stop()

    async def _cleanup(self):
        """Close and drop the HTTP session, if one was created."""
        if self._session is not None:
            await self._session.close()
            self._session = None

    async def __aenter__(self):
        return self

    async def __aexit__(self, typ, value, traceback):
        await self._cleanup()

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

    def __del__(self):
        # __del__ runs even when __init__ raised early. ``_loop_runner`` is
        # the last attribute assigned in __init__, so check for it *before*
        # touching ``self.asynchronous`` (which needs ``_asynchronous``);
        # otherwise a failed construction raises AttributeError here.
        if (hasattr(self, "_loop_runner") and not self.asynchronous
                and not sys.is_finalizing()):
            self.close()

    def __repr__(self):
        return "Gateway<%s>" % self.address

    async def _request(self, method, url, json=None):
        """Send a request, retrying once with credentials on a 401 response.

        Returns the response for status codes below 400. Otherwise raises
        ``ValueError`` (404, 422), ``GatewayClusterError`` (409),
        ``GatewayServerError`` (500), or whatever ``raise_for_status``
        raises for other codes.
        """
        if self._session is None:
            self._session = aiohttp.ClientSession()
        session = self._session

        resp = await session.request(method, url, json=json,
                                     **self._request_kwargs)

        if resp.status == 401:
            headers, context = self.auth.pre_request(resp)
            resp = await session.request(method, url, json=json,
                                         headers=headers,
                                         **self._request_kwargs)
            self.auth.post_response(resp, context)

        if resp.status >= 400:
            # Prefer the JSON "error" field; fall back to the raw body text.
            try:
                msg = await resp.json()
                msg = msg["error"]
            except Exception:
                msg = await resp.text()

            if resp.status in {404, 422}:
                raise ValueError(msg)
            elif resp.status == 409:
                raise GatewayClusterError(msg)
            elif resp.status == 500:
                raise GatewayServerError(msg)
            else:
                resp.raise_for_status()
        else:
            return resp

    async def _clusters(self, status=None):
        if status is not None:
            # Accept a single status or an iterable of statuses.
            if isinstance(status, (str, ClusterStatus)):
                status = [ClusterStatus._create(status)]
            else:
                status = [ClusterStatus._create(s) for s in status]
            query = "?status=" + ",".join(s.name for s in status)
        else:
            query = ""

        url = "%s/api/v1/clusters/%s" % (self.address, query)
        resp = await self._request("GET", url)
        data = await resp.json()
        return [
            ClusterReport._from_json(self._public_address, self.proxy_address,
                                     r) for r in data.values()
        ]

    def list_clusters(self, status=None, **kwargs):
        """List clusters for this user.

        Parameters
        ----------
        status : ClusterStatus, str, or list, optional
            The cluster status (or statuses) to select. Valid options are
            'pending', 'running', 'stopping', 'stopped', 'failed'.
            By default selects active clusters ('pending', 'running').

        Returns
        -------
        clusters : list of ClusterReport
        """
        return self.sync(self._clusters, status=status, **kwargs)

    def get_cluster(self, cluster_name, **kwargs):
        """Get information about a specific cluster.

        Parameters
        ----------
        cluster_name : str
            The cluster name.

        Returns
        -------
        report : ClusterReport
        """
        return self.sync(self._cluster_report, cluster_name, **kwargs)

    async def _get_versions(self):
        url = "%s/api/version" % self.address
        resp = await self._request("GET", url)
        server_info = await resp.json()
        from . import __version__

        return {
            "server": server_info,
            "client": {
                "version": __version__
            },
        }

    def get_versions(self):
        """Return version info for the server and client

        Returns
        -------
        version_info : dict
        """
        return self.sync(self._get_versions)

    def _config_cluster_options(self):
        """Return the ``gateway.cluster.options`` config with templates
        formatted."""
        opts = dask.config.get("gateway.cluster.options")
        return {k: format_template(v) for k, v in opts.items()}

    async def _cluster_options(self, use_local_defaults=True):
        url = "%s/api/v1/options" % self.address
        resp = await self._request("GET", url)
        data = await resp.json()
        options = Options._from_spec(data["cluster_options"])
        if use_local_defaults:
            options.update(self._config_cluster_options())
        return options

    def cluster_options(self, use_local_defaults=True, **kwargs):
        """Get the available cluster configuration options.

        Parameters
        ----------
        use_local_defaults : bool, optional
            Whether to use any default options from the local configuration.
            Default is True, set to False to use only the server-side
            defaults.

        Returns
        -------
        cluster_options : dask_gateway.options.Options
            A dict of cluster options.
        """
        return self.sync(self._cluster_options,
                         use_local_defaults=use_local_defaults, **kwargs)

    async def _submit(self, cluster_options=None, **kwargs):
        url = "%s/api/v1/clusters/" % self.address
        if cluster_options is not None:
            if not isinstance(cluster_options, Options):
                raise TypeError(
                    "cluster_options must be an `Options`, got %r" %
                    type(cluster_options).__name__)
            options = dict(cluster_options)
            options.update(kwargs)
        else:
            # No explicit options: start from the local configuration.
            options = self._config_cluster_options()
            options.update(kwargs)
        resp = await self._request("POST", url,
                                   json={"cluster_options": options})
        data = await resp.json()
        return data["name"]

    def submit(self, cluster_options=None, **kwargs):
        """Submit a new cluster to be started.

        This returns quickly with a ``cluster_name``, which can later be used
        to connect to the cluster.

        Parameters
        ----------
        cluster_options : dask_gateway.options.Options, optional
            An ``Options`` object describing the desired cluster configuration.
        **kwargs :
            Additional cluster configuration options. If ``cluster_options`` is
            provided, these are applied afterwards as overrides. Available
            options are specific to each deployment of dask-gateway, see
            ``cluster_options`` for more information.

        Returns
        -------
        cluster_name : str
            The cluster name.
        """
        return self.sync(self._submit, cluster_options=cluster_options,
                         **kwargs)

    async def _cluster_report(self, cluster_name, wait=False):
        params = "?wait" if wait else ""
        url = "%s/api/v1/clusters/%s%s" % (self.address, cluster_name, params)
        resp = await self._request("GET", url)
        data = await resp.json()
        return ClusterReport._from_json(self._public_address,
                                        self.proxy_address, data)

    async def _wait_for_start(self, cluster_name):
        """Poll until ``cluster_name`` is running and return its report.

        Raises ``GatewayClusterError`` if the cluster failed or is already
        stopped. Timeouts while polling are retried.
        """
        while True:
            try:
                report = await self._cluster_report(cluster_name, wait=True)
            except TimeoutError:
                # Timeout, ignore
                pass
            else:
                if report.status is ClusterStatus.RUNNING:
                    return report
                elif report.status is ClusterStatus.FAILED:
                    raise GatewayClusterError(
                        "Cluster %r failed to start, see logs for "
                        "more information" % cluster_name)
                elif report.status is ClusterStatus.STOPPED:
                    raise GatewayClusterError("Cluster %r is already stopped" %
                                              cluster_name)
            # Not started yet, try again later
            await asyncio.sleep(0.5)

    def connect(self, cluster_name, shutdown_on_close=False):
        """Connect to a submitted cluster.

        Parameters
        ----------
        cluster_name : str
            The cluster to connect to.
        shutdown_on_close : bool, optional
            If True, the cluster will be automatically shutdown on close.
            Default is False.

        Returns
        -------
        cluster : GatewayCluster
        """
        return GatewayCluster.from_name(
            cluster_name,
            shutdown_on_close=shutdown_on_close,
            address=self.address,
            proxy_address=self.proxy_address,
            auth=self.auth,
            asynchronous=self.asynchronous,
            loop=self.loop,
        )

    def new_cluster(self, cluster_options=None, shutdown_on_close=True,
                    **kwargs):
        """Submit a new cluster to the gateway, and wait for it to be started.

        Same as calling ``submit`` and ``connect`` in one go.

        Parameters
        ----------
        cluster_options : dask_gateway.options.Options, optional
            An ``Options`` object describing the desired cluster configuration.
        shutdown_on_close : bool, optional
            If True (default), the cluster will be automatically shutdown on
            close. Set to False to have cluster persist until explicitly
            shutdown.
        **kwargs :
            Additional cluster configuration options. If ``cluster_options`` is
            provided, these are applied afterwards as overrides. Available
            options are specific to each deployment of dask-gateway, see
            ``cluster_options`` for more information.

        Returns
        -------
        cluster : GatewayCluster
        """
        return GatewayCluster(
            address=self.address,
            proxy_address=self.proxy_address,
            auth=self.auth,
            asynchronous=self.asynchronous,
            loop=self.loop,
            cluster_options=cluster_options,
            shutdown_on_close=shutdown_on_close,
            **kwargs,
        )

    async def _stop_cluster(self, cluster_name):
        url = "%s/api/v1/clusters/%s" % (self.address, cluster_name)
        await self._request("DELETE", url)

    def stop_cluster(self, cluster_name, **kwargs):
        """Stop a cluster.

        Parameters
        ----------
        cluster_name : str
            The cluster name.
        """
        return self.sync(self._stop_cluster, cluster_name, **kwargs)

    async def _scale_cluster(self, cluster_name, n):
        url = "%s/api/v1/clusters/%s/scale" % (self.address, cluster_name)
        await self._request("POST", url, json={"count": n})

    def scale_cluster(self, cluster_name, n, **kwargs):
        """Scale a cluster to n workers.

        Parameters
        ----------
        cluster_name : str
            The cluster name.
        n : int
            The number of workers to scale to.
        """
        return self.sync(self._scale_cluster, cluster_name, n, **kwargs)

    async def _adapt_cluster(self, cluster_name, minimum=None, maximum=None,
                             active=True):
        await self._request(
            "POST",
            "%s/api/v1/clusters/%s/adapt" % (self.address, cluster_name),
            json={
                "minimum": minimum,
                "maximum": maximum,
                "active": active
            },
        )

    def adapt_cluster(self, cluster_name, minimum=None, maximum=None,
                      active=True, **kwargs):
        """Configure adaptive scaling for a cluster.

        Parameters
        ----------
        cluster_name : str
            The cluster name.
        minimum : int, optional
            The minimum number of workers to scale to. Defaults to 0.
        maximum : int, optional
            The maximum number of workers to scale to. Defaults to infinity.
        active : bool, optional
            If ``True`` (default), adaptive scaling is activated. Set to
            ``False`` to deactivate adaptive scaling.
        """
        return self.sync(
            self._adapt_cluster,
            cluster_name,
            minimum=minimum,
            maximum=maximum,
            active=active,
            **kwargs,
        )
class AzureMLCluster(Cluster):
    """ Deploy a Dask cluster using Azure ML

    This creates a dask scheduler and workers on an Azure ML Compute Target.

    Arguments left as ``None`` are filled in from the ``cloudprovider.azure``
    dask configuration when the cluster starts.

    Parameters
    ----------
    workspace: azureml.core.Workspace (required)
        Azure ML Workspace - see https://aka.ms/azureml/workspace

    compute_target: azureml.core.ComputeTarget (required)
        Azure ML Compute Target - see https://aka.ms/azureml/computetarget

    environment_definition: azureml.core.Environment (required)
        Azure ML Environment - see https://aka.ms/azureml/environments

    experiment_name: str (optional)
        The name of the Azure ML Experiment used to control the cluster.

        Defaults to ``dask-cloudprovider``.

    initial_node_count: int (optional)
        The initial number of nodes for the Dask Cluster.

        Defaults to ``1``.

    jupyter: bool (optional)
        Flag to start JupyterLab session on the headnode of the cluster.

        Defaults to ``False``.

    jupyter_port: int (optional)
        Port on headnode to use for hosting JupyterLab session.

        Defaults to ``9000``.

    dashboard_port: int (optional)
        Port on headnode to use for hosting Dask dashboard.

        Defaults to ``9001``.

    scheduler_port: int (optional)
        Port to map the scheduler port to via SSH-tunnel if machine not on
        the same VNET.

        Defaults to ``9002``.

    scheduler_idle_timeout: (optional)
        Passed to the scheduler start script as ``--scheduler_idle_timeout``.

    worker_death_timeout: (optional)
        Passed to the worker start script as ``--worker_death_timeout``.

    additional_ports: list[tuple[int, int]] (optional)
        Additional ports to forward. This requires a list of tuples where the
        first element is the port to open on the headnode while the second
        element is the port to map to or forward via the SSH-tunnel.

        Defaults to ``[]``.

    admin_username: str (optional)
        Username of the admin account for the AzureML Compute.
        Required for runs that are not on the same VNET. Throws Exception if
        left empty while the machine is not on the same VNET.

        Defaults to ``""``.

    admin_ssh_key: str (optional)
        Location of the SSH secret key used when creating the AzureML Compute.
        The key should be passwordless if run from a Jupyter notebook.
        The ``id_rsa`` file needs to have 0700 permissions set.
        Required for runs that are not on the same VNET. Throws Exception if
        left empty while the machine is not on the same VNET.

        Defaults to ``""``.

    datastores: List[str] (optional)
        List of Azure ML Datastores to be mounted on the headnode -
        see https://aka.ms/azureml/data and https://aka.ms/azureml/datastores.

        Defaults to ``[]``. To mount all datastores in the workspace,
        set to ``[ws.datastores[datastore] for datastore in ws.datastores]``.

    code_store: (optional)
        Accepted for interface compatibility; not used by this class.

    asynchronous: bool (optional)
        Flag to run jobs asynchronously.

    **kwargs: dict
        Additional keyword arguments.
    """

    def __init__(
        self,
        workspace,
        compute_target,
        environment_definition,
        experiment_name=None,
        initial_node_count=None,
        jupyter=None,
        jupyter_port=None,
        dashboard_port=None,
        scheduler_port=None,
        scheduler_idle_timeout=None,
        worker_death_timeout=None,
        additional_ports=None,
        admin_username=None,
        admin_ssh_key=None,
        datastores=None,
        code_store=None,
        asynchronous=False,
        **kwargs,
    ):
        ### REQUIRED PARAMETERS
        self.workspace = workspace
        self.compute_target = compute_target
        self.environment_definition = environment_definition

        ### EXPERIMENT DEFINITION
        self.experiment_name = experiment_name

        ### ENVIRONMENT AND VARIABLES
        self.initial_node_count = initial_node_count

        ### GPU RUN INFO
        # Map lower-cased VM size name -> GPU count for this workspace.
        self.workspace_vm_sizes = {
            e["name"].lower(): e["gpus"]
            for e in AmlCompute.supported_vmsizes(self.workspace)
        }

        self.compute_target_vm_size = self.compute_target.serialize()["properties"][
            "status"
        ]["vmSize"].lower()
        self.n_gpus_per_node = self.workspace_vm_sizes[self.compute_target_vm_size]
        self.use_gpu = self.n_gpus_per_node > 0

        ### JUPYTER AND PORT FORWARDING
        self.jupyter = jupyter
        self.jupyter_port = jupyter_port
        self.dashboard_port = dashboard_port
        self.scheduler_port = scheduler_port
        self.scheduler_idle_timeout = scheduler_idle_timeout
        self.worker_death_timeout = worker_death_timeout

        if additional_ports is not None:
            self._validate_additional_ports(additional_ports)

        self.additional_ports = additional_ports

        self.admin_username = admin_username
        self.admin_ssh_key = admin_ssh_key
        self.scheduler_ip_port = (
            None  ### INIT FOR HOLDING THE ADDRESS FOR THE SCHEDULER
        )

        ### DATASTORES
        self.datastores = datastores

        ### FUTURE EXTENSIONS
        self.kwargs = kwargs

        ### RUNNING IN MATRIX OR LOCAL
        self.same_vnet = None

        ### GET RUNNING LOOP
        self._loop_runner = LoopRunner(loop=None, asynchronous=asynchronous)
        self.loop = self._loop_runner.loop

        self.abs_path = pathlib.Path(__file__).parent.absolute()

        ### INITIALIZE CLUSTER
        super().__init__(asynchronous=asynchronous)

        if not self.asynchronous:
            self._loop_runner.start()
            self.sync(self.__get_defaults)
            self.sync(self.__create_cluster)

    @staticmethod
    def _validate_additional_ports(additional_ports):
        """Raise ``TypeError`` unless ``additional_ports`` is a list of
        2-tuples of ints.

        Exact types are required (``type(...) is``), so subclasses such as
        ``bool`` are rejected.
        """
        if type(additional_ports) is not list:
            error_message = (
                f"The additional_ports parameter is of {type(additional_ports)}"
                " type but needs to be a list of int tuples."
                " Check the documentation."
            )
            # Not inside an exception handler, so log at error level without
            # a (non-existent) traceback.
            logger.error(error_message)
            raise TypeError(error_message)

        if not additional_ports:
            return

        if type(additional_ports[0]) is not tuple:
            error_message = (
                f"The additional_ports elements are of {type(additional_ports[0])}"
                " type but needs to be a list of int tuples."
                " Check the documentation."
            )
            raise TypeError(error_message)

        ### check if all elements are tuples of length two and int type
        all_correct = all(
            type(el) is tuple
            and len(el) == 2
            and type(el[0]) is int
            and type(el[1]) is int
            for el in additional_ports
        )
        if not all_correct:
            error_message = (
                "At least one of the elements of the additional_ports parameter"
                " is wrong. Make sure it is a list of int tuples."
                " Check the documentation."
            )
            raise TypeError(error_message)

    async def __get_defaults(self):
        """Fill unset options from the ``cloudprovider.azure`` dask config and
        prepare the scheduler/worker script parameters."""
        self.config = dask.config.get("cloudprovider.azure", {})

        if self.experiment_name is None:
            self.experiment_name = self.config.get("experiment_name")

        if self.initial_node_count is None:
            self.initial_node_count = self.config.get("initial_node_count")

        if self.jupyter is None:
            self.jupyter = self.config.get("jupyter")

        if self.jupyter_port is None:
            self.jupyter_port = self.config.get("jupyter_port")

        if self.dashboard_port is None:
            self.dashboard_port = self.config.get("dashboard_port")

        if self.scheduler_port is None:
            self.scheduler_port = self.config.get("scheduler_port")

        if self.scheduler_idle_timeout is None:
            self.scheduler_idle_timeout = self.config.get("scheduler_idle_timeout")

        if self.worker_death_timeout is None:
            self.worker_death_timeout = self.config.get("worker_death_timeout")

        if self.additional_ports is None:
            self.additional_ports = self.config.get("additional_ports")

        if self.admin_username is None:
            self.admin_username = self.config.get("admin_username")

        if self.admin_ssh_key is None:
            self.admin_ssh_key = self.config.get("admin_ssh_key")

        if self.datastores is None:
            self.datastores = self.config.get("datastores")

        ### PARAMETERS TO START THE CLUSTER
        self.scheduler_params = {}
        self.worker_params = {}

        ### scheduler and worker parameters
        self.scheduler_params["--jupyter"] = self.jupyter
        self.scheduler_params["--scheduler_idle_timeout"] = self.scheduler_idle_timeout
        self.worker_params["--worker_death_timeout"] = self.worker_death_timeout

        if self.use_gpu:
            self.scheduler_params["--use_gpu"] = True
            self.scheduler_params["--n_gpus_per_node"] = self.n_gpus_per_node
            self.worker_params["--use_gpu"] = True
            self.worker_params["--n_gpus_per_node"] = self.n_gpus_per_node

        ### CLUSTER PARAMS
        self.max_nodes = self.compute_target.serialize()["properties"]["properties"][
            "scaleSettings"
        ]["maxNodeCount"]
        self.scheduler_ip_port = None
        self.workers_list = []
        self.URLs = {}

        ### SANITY CHECKS
        ###-----> initial node count
        if self.initial_node_count > self.max_nodes:
            self.initial_node_count = self.max_nodes

    def __print_message(self, msg, length=80, filler="#", pre_post=""):
        """Log ``msg`` at info level and print it centered in a banner."""
        logger.info(msg)
        print(f"{pre_post} {msg} {pre_post}".center(length, filler))

    async def __check_if_scheduler_ip_reachable(self):
        """
        Private method to determine if running in the cloud within the same VNET
        and the scheduler node is reachable
        """
        try:
            ip, port = self.scheduler_ip_port.split(":")
            # Only reachability matters, so close the probe socket right away
            # instead of leaking it.
            with socket.create_connection((ip, port), 10):
                pass
            self.same_vnet = True
            # __print_message already logs the message at info level.
            self.__print_message("On the same VNET")
        except socket.timeout:
            self.__print_message("Not on the same VNET")
            self.same_vnet = False
        except ConnectionRefusedError as e:
            # same_vnet stays None, so the caller probes again.
            logger.info(e)

    def __prepare_rpc_connection_to_headnode(self):
        """Return the address to use for the scheduler RPC connection.

        On the same VNET this is the scheduler address reported by the run;
        otherwise it is the local end of the SSH tunnel, which requires
        ``admin_username`` and ``admin_ssh_key`` to be set.
        """
        if self.same_vnet:
            return self.run.get_metrics()["scheduler"]

        if self.admin_username == "" or self.admin_ssh_key == "":
            message = "Your machine is not at the same VNET as the cluster. "
            message += "You need to set admin_username and admin_ssh_key. Check documentation."
            logger.error(message)
            raise Exception(message)

        uri = f"{socket.gethostname()}:{self.scheduler_port}"
        logger.info(f"Local connection: {uri}")
        return uri

    async def __create_cluster(self):
        """Submit the scheduler run, wait for its address, set up port
        forwarding and scale to the initial node count."""
        # set up environment
        self.__print_message("Setting up cluster")

        # submit run
        self.__print_message("Submitting the experiment")

        exp = Experiment(self.workspace, self.experiment_name)
        estimator = Estimator(
            os.path.join(self.abs_path, "setup"),
            compute_target=self.compute_target,
            entry_script="start_scheduler.py",
            environment_definition=self.environment_definition,
            script_params=self.scheduler_params,
            node_count=1,  ### start only scheduler
            distributed_training=MpiConfiguration(),
            use_docker=True,
            inputs=self.datastores,
        )

        run = exp.submit(estimator)

        self.__print_message("Waiting for scheduler node's IP")
        while (
            run.get_status() != "Canceled"
            and run.get_status() != "Failed"
            and "scheduler" not in run.get_metrics()
        ):
            print(".", end="")
            logger.info("Scheduler not ready")
            time.sleep(5)

        if run.get_status() == "Canceled" or run.get_status() == "Failed":
            logger.error("Failed to start the AzureML cluster")
            raise Exception("Failed to start the AzureML cluster.")

        print("\n\n")

        ### SET FLAGS
        self.scheduler_ip_port = run.get_metrics()["scheduler"]
        self.worker_params["--scheduler_ip_port"] = self.scheduler_ip_port
        self.__print_message(f'Scheduler: {run.get_metrics()["scheduler"]}')
        self.run = run

        logger.info(f'Scheduler: {run.get_metrics()["scheduler"]}')

        ### CHECK IF ON THE SAME VNET
        while self.same_vnet is None:
            await self.sync(self.__check_if_scheduler_ip_reachable)
            time.sleep(1)

        ### REQUIRED BY dask.distributed.deploy.cluster.Cluster
        _scheduler = self.__prepare_rpc_connection_to_headnode()
        self.scheduler_comm = rpc(_scheduler)
        await self.sync(self.__setup_port_forwarding)
        await self.sync(super()._start)
        await self.sync(self.__update_links)

        self.__print_message("Connections established")
        self.__print_message(f"Scaling to {self.initial_node_count} workers")

        if self.initial_node_count > 1:
            self.scale(
                self.initial_node_count
            )  # LOGIC TO KEEP PROPER TRACK OF WORKERS IN `scale`
        self.__print_message("Scaling is done")

    async def __update_links(self):
        """Store the dashboard and Jupyter URLs in ``scheduler_info``."""
        hostname = socket.gethostname()
        location = self.workspace.get_details()["location"]
        token = self.run.get_metrics()["token"]

        if self.same_vnet:
            self.scheduler_info[
                "dashboard_url"
            ] = f"https://{hostname}-{self.dashboard_port}.{location}.instances.azureml.net/status"

            self.scheduler_info[
                "jupyter_url"
            ] = f"https://{hostname}-{self.jupyter_port}.{location}.instances.azureml.net/lab?token={token}"
        else:
            self.scheduler_info[
                "dashboard_url"
            ] = f"http://{hostname}:{self.dashboard_port}"
            self.scheduler_info[
                "jupyter_url"
            ] = f"http://{hostname}:{self.jupyter_port}/?token={token}"

        logger.info(f'Dashboard URL: {self.scheduler_info["dashboard_url"]}')
        logger.info(f'Jupyter URL:   {self.scheduler_info["jupyter_url"]}')

    async def __setup_port_forwarding(self):
        """Forward the dashboard, Jupyter and any additional ports.

        Uses ``socat`` when on the same VNET and an SSH tunnel otherwise.
        """
        dashboard_address = self.run.get_metrics()["dashboard"]
        jupyter_address = self.run.get_metrics()["jupyter"]
        scheduler_ip = self.run.get_metrics()["scheduler"].split(":")[0]
        # The configured value may be None; treat that as "no extra ports".
        additional_ports = self.additional_ports or []

        if self.same_vnet:
            os.system(
                "killall socat"
            )  # kill all socat processes - cleans up previous port forward setups
            os.system(
                f"setsid socat tcp-listen:{self.dashboard_port},reuseaddr,fork tcp:{dashboard_address} &"
            )
            os.system(
                f"setsid socat tcp-listen:{self.jupyter_port},reuseaddr,fork tcp:{jupyter_address} &"
            )

            ### map additional ports
            for port in additional_ports:
                # port is (headnode port, local port); the original read the
                # non-existent attribute ``self.port`` here.
                os.system(
                    f"setsid socat tcp-listen:{port[1]},reuseaddr,fork tcp:{scheduler_ip}:{port[0]} &"
                )
        else:
            scheduler_public_ip = self.compute_target.list_nodes()[0]["publicIpAddress"]
            scheduler_public_port = self.compute_target.list_nodes()[0]["port"]

            cmd = (
                "ssh -vvv -o StrictHostKeyChecking=no -N"
                f" -i {self.admin_ssh_key}"
                f" -L 0.0.0.0:{self.jupyter_port}:{scheduler_ip}:8888"
                f" -L 0.0.0.0:{self.dashboard_port}:{scheduler_ip}:8787"
                f" -L 0.0.0.0:{self.scheduler_port}:{scheduler_ip}:8786"
            )

            for port in additional_ports:
                cmd += f" -L 0.0.0.0:{port[1]}:{scheduler_ip}:{port[0]}"

            cmd += f" {self.admin_username}@{scheduler_public_ip} -p {scheduler_public_port}"

            # Send the (very verbose) ssh output to the log file rather than
            # to a pipe nobody reads. The child keeps its own copy of the
            # descriptor, so the parent's handle can be closed immediately.
            with open("portforward_out_log.txt", "w") as portforward_log:
                subprocess.Popen(
                    cmd.split(),
                    universal_newlines=True,
                    stdout=portforward_log,
                    stderr=subprocess.STDOUT,
                )

    @property
    def dashboard_link(self):
        """ Link to Dask dashboard.
        """
        try:
            return self.scheduler_info["dashboard_url"]
        except KeyError:
            return ""

    @property
    def jupyter_link(self):
        """ Link to JupyterLab on running on the headnode of the cluster.
        Set ``jupyter=True`` when creating the ``AzureMLCluster``.
        """
        try:
            return self.scheduler_info["jupyter_url"]
        except KeyError:
            return ""

    def _format_nodes(self, nodes, requested, use_gpu, n_gpus_per_node=None):
        """Format the node count, as ``"nodes / requested"`` when they differ.

        ``use_gpu`` and ``n_gpus_per_node`` do not affect the result; they are
        kept for interface compatibility.
        """
        if nodes == requested:
            return f"{nodes}"
        return f"{nodes} / {requested}"

    def _widget_status(self):
        """Return the HTML status table (nodes, workers, memory)."""
        ### reporting proper number of nodes vs workers in a multi-GPU worker scenario
        nodes = len(self.scheduler_info["workers"])

        if self.use_gpu:
            nodes = int(nodes / self.n_gpus_per_node)
        if hasattr(self, "worker_spec"):
            requested = sum(
                1 if "group" not in each else len(each["group"])
                for each in self.worker_spec.values()
            )
        elif hasattr(self, "nodes"):
            requested = len(self.nodes)
        else:
            requested = nodes

        nodes = self._format_nodes(nodes, requested, self.use_gpu, self.n_gpus_per_node)

        cores = sum(v["nthreads"] for v in self.scheduler_info["workers"].values())
        cores_or_gpus = "Workers (GPUs)" if self.use_gpu else "Workers (vCPUs)"

        memory = (
            sum(
                v["gpu"]["memory-total"][0]
                for v in self.scheduler_info["workers"].values()
            )
            if self.use_gpu
            else sum(v["memory_limit"] for v in self.scheduler_info["workers"].values())
        )
        memory = format_bytes(memory)

        text = """
<div>
  <style scoped>
    .dataframe tbody tr th:only-of-type {
        vertical-align: middle;
    }

    .dataframe tbody tr th {
        vertical-align: top;
    }

    .dataframe thead th {
        text-align: right;
    }
  </style>
  <table style="text-align: right;">
    <tr> <th>Nodes</th> <td>%s</td></tr>
    <tr> <th>%s</th> <td>%s</td></tr>
    <tr> <th>Memory</th> <td>%s</td></tr>
  </table>
</div>
""" % (
            nodes,
            cores_or_gpus,
            cores,
            memory,
        )
        return text

    def _widget(self):
        """ Create IPython widget for display within a notebook """
        try:
            return self._cached_widget
        except AttributeError:
            pass

        try:
            from ipywidgets import Layout, VBox, HBox, IntText, Button, HTML, Accordion
        except ImportError:
            # Without ipywidgets there is nothing to show; remember that.
            self._cached_widget = None
            return None

        layout = Layout(width="150px")

        if self.dashboard_link:
            dashboard_link = (
                '<p><b>Dashboard: </b><a href="%s" target="_blank">%s</a></p>\n'
                % (self.dashboard_link, self.dashboard_link,)
            )
        else:
            dashboard_link = ""

        if self.jupyter_link:
            jupyter_link = (
                '<p><b>Jupyter: </b><a href="%s" target="_blank">%s</a></p>\n'
                % (self.jupyter_link, self.jupyter_link,)
            )
        else:
            jupyter_link = ""

        title = "<h2>%s</h2>" % self._cluster_class_name
        title = HTML(title)
        dashboard = HTML(dashboard_link)
        jupyter = HTML(jupyter_link)

        status = HTML(self._widget_status(), layout=Layout(min_width="150px"))

        if self._supports_scaling:
            request = IntText(
                self.initial_node_count, description="Nodes", layout=layout
            )
            scale = Button(description="Scale", layout=layout)

            minimum = IntText(0, description="Minimum", layout=layout)
            maximum = IntText(0, description="Maximum", layout=layout)
            adapt = Button(description="Adapt", layout=layout)

            accordion = Accordion(
                [HBox([request, scale]), HBox([minimum, maximum, adapt])],
                layout=Layout(min_width="500px"),
            )
            accordion.selected_index = None
            accordion.set_title(0, "Manual Scaling")
            accordion.set_title(1, "Adaptive Scaling")

            # ``update`` is defined further down; the callbacks only look it
            # up when a button is clicked.
            def adapt_cb(b):
                self.adapt(minimum=minimum.value, maximum=maximum.value)
                update()

            adapt.on_click(adapt_cb)

            def scale_cb(b):
                with log_errors():
                    n = request.value
                    with ignoring(AttributeError):
                        self._adaptive.stop()
                    self.scale(n)
                    update()

            scale.on_click(scale_cb)
        else:
            accordion = HTML("")

        box = VBox([title, HBox([status, accordion]), jupyter, dashboard])

        self._cached_widget = box

        def update():
            self.close_when_disconnect()
            status.value = self._widget_status()

        pc = PeriodicCallback(update, 500, io_loop=self.loop)
        self.periodic_callbacks["cluster-repr"] = pc
        pc.start()

        return box

    def close_when_disconnect(self):
        """Remove all scaled workers once the scheduler run has ended."""
        # Read the status once so all three comparisons see the same value.
        if self.run.get_status() in ("Canceled", "Completed", "Failed"):
            self.scale_down(len(self.workers_list))

    def scale(self, workers=1):
        """ Scale the cluster. Scales to a maximum of the workers available in the cluster.
        """
        if workers <= 0:
            self.close()
            return

        count = len(self.workers_list) + 1  # one more worker in head node

        if count < workers:
            self.scale_up(workers - count)
        elif count > workers:
            self.scale_down(count - workers)
        else:
            self.__print_message(f"Number of workers: {workers}")

    # scale up
    def scale_up(self, workers=1):
        """ Scale up the number of workers.
        """
        run_config = RunConfiguration()
        run_config.target = self.compute_target
        run_config.environment = self.environment_definition

        scheduler_ip = self.run.get_metrics()["scheduler"]
        args = [
            f"--scheduler_ip_port={scheduler_ip}",
            f"--use_gpu={self.use_gpu}",
            f"--n_gpus_per_node={self.n_gpus_per_node}",
            f"--worker_death_timeout={self.worker_death_timeout}",
        ]

        child_run_config = ScriptRunConfig(
            source_directory=os.path.join(self.abs_path, "setup"),
            script="start_worker.py",
            arguments=args,
            run_config=run_config,
        )

        # Each worker is a child run of the scheduler run.
        for _ in range(workers):
            child_run = self.run.submit_child(child_run_config)
            self.workers_list.append(child_run)

    # scale down
    def scale_down(self, workers=1):
        """ Scale down the number of workers. Scales to minimum of 1.
        """
        for _ in range(workers):
            if self.workers_list:
                child_run = self.workers_list.pop(0)  # deactive oldest workers
                child_run.complete()  # complete() will mark the run "Complete", but won't kill the process
                child_run.cancel()
            else:
                self.__print_message("All scaled workers are removed.")

    # close cluster
    async def _close(self):
        """Complete and cancel all worker runs and the scheduler run, then
        close the underlying cluster."""
        if self.status == "closed":
            return

        while self.workers_list:
            child_run = self.workers_list.pop()
            child_run.complete()
            child_run.cancel()

        if self.run:
            self.run.complete()
            self.run.cancel()

        await super()._close()
        self.status = "closed"
        self.__print_message("Scheduler and workers are disconnected.")

    def close(self):
        """ Close the cluster. All Azure ML Runs corresponding to the scheduler
        and worker processes will be completed. The Azure ML Compute Target will
        return to its minimum number of nodes after its idle time before scaledown.
        """
        return self.sync(self._close)
class HelmCluster(Cluster):
    """Connect to a Dask cluster deployed via the Helm Chart.

    This cluster manager connects to an existing Dask deployment that was
    created by the Dask Helm Chart, enabling you to perform basic cluster
    actions such as scaling and log retrieval.

    Parameters
    ----------
    release_name: str
        Name of the helm release to connect to.
    namespace: str (optional)
        Namespace in which to launch the workers.
        Defaults to current namespace if available or "default"
    port_forward_cluster_ip: bool (optional)
        If the chart uses ClusterIP type services, forward the ports locally.
        If you are using ``HelmCluster`` from the Jupyter session that was
        installed by the helm chart this should be ``False``. If you are
        running it locally it should be ``True``.
    auth: List[ClusterAuth] (optional)
        Configuration methods to attempt in order. Defaults to
        ``[InCluster(), KubeConfig()]``.
    scheduler_name: str (optional)
        Name of the Dask scheduler deployment in the current release.
        Defaults to "scheduler".
    worker_name: str (optional)
        Name of the Dask worker deployment in the current release.
        Defaults to "worker".
    **kwargs: dict
        Additional keyword arguments to pass to Cluster

    Examples
    --------
    >>> from dask_kubernetes import HelmCluster
    >>> cluster = HelmCluster(release_name="myhelmrelease")

    You can then resize the cluster with the scale method

    >>> cluster.scale(10)

    You can pass this cluster directly to a Dask client

    >>> from dask.distributed import Client
    >>> client = Client(cluster)

    You can also access cluster logs

    >>> cluster.get_logs()

    See Also
    --------
    HelmCluster.scale
    HelmCluster.get_logs
    """

    def __init__(
        self,
        release_name=None,
        auth=ClusterAuth.DEFAULT,
        namespace=None,
        port_forward_cluster_ip=False,
        loop=None,
        asynchronous=False,
        scheduler_name="scheduler",
        worker_name="worker",
        **kwargs,
    ):
        """Validate the helm release and, in synchronous mode, connect to it.

        Raises
        ------
        RuntimeError
            If the ``helm`` executable is missing or ``helm status`` exits
            with a non-zero code for ``release_name``.
        """
        self.release_name = release_name
        self.namespace = namespace or _namespace_default()
        self.check_helm_dependency()
        # Fail early, before touching the Kubernetes API, when helm does not
        # know the release in this namespace.
        status = subprocess.run(
            ["helm", "-n", self.namespace, "status", self.release_name],
            capture_output=True,
            encoding="utf-8",
        )
        if status.returncode != 0:
            raise RuntimeError(f"No such helm release {self.release_name}.")
        self.auth = auth
        # The API client and scheduler comm are created in ``_start``.
        self.core_api = None
        self.scheduler_comm = None
        self.port_forward_cluster_ip = port_forward_cluster_ip
        self._supports_scaling = True
        self._loop_runner = LoopRunner(loop=loop, asynchronous=asynchronous)
        self.loop = self._loop_runner.loop
        self.scheduler_name = scheduler_name
        self.worker_name = worker_name
        super().__init__(asynchronous=asynchronous, **kwargs)
        if not self.asynchronous:
            self._loop_runner.start()
            self.sync(self._start)

    @staticmethod
    def check_helm_dependency():
        """Raise ``RuntimeError`` if no ``helm`` executable is on the PATH."""
        if shutil.which("helm") is None:
            raise RuntimeError(
                "Missing dependency helm. "
                "Please install helm following the instructions for your OS. "
                "https://helm.sh/docs/intro/install/"
            )

    async def _start(self):
        """Load credentials, create the API clients and open the scheduler comm."""
        await ClusterAuth.load_first(self.auth)
        self.core_api = kubernetes.client.CoreV1Api()
        self.apps_api = kubernetes.client.AppsV1Api()
        self.scheduler_comm = rpc(await self._get_scheduler_address())
        await super()._start()

    async def _get_scheduler_address(self):
        """Return the ``tcp://host:port`` address of the scheduler service.

        Raises
        ------
        ValueError
            From the unpacking below, unless exactly one service port is
            named after the service.
        RuntimeError
            If the service type is not LoadBalancer, NodePort or ClusterIP.
        """
        service_name = f"{self.release_name}-{self.scheduler_name}"
        service = await self.core_api.read_namespaced_service(
            service_name, self.namespace
        )
        [scheduler_port] = [
            candidate
            for candidate in service.spec.ports
            if candidate.name == service_name
        ]
        port = scheduler_port.port
        if service.spec.type == "LoadBalancer":
            lb = service.status.load_balancer.ingress[0]
            host = lb.hostname or lb.ip
            return f"tcp://{host}:{port}"
        elif service.spec.type == "NodePort":
            nodes = await self.core_api.list_node()
            host = nodes.items[0].status.addresses[0].address
            # A NodePort service is exposed on the node's address at the
            # allocated node port, which is generally not the service port.
            return f"tcp://{host}:{scheduler_port.node_port}"
        elif service.spec.type == "ClusterIP":
            if self.port_forward_cluster_ip:
                warnings.warn(
                    f"""
                    Sorry we do not currently support local port forwarding.

                    Please port-forward the service locally yourself with the following command.

                    kubectl port-forward --namespace {self.namespace} svc/{service_name} {port}:{port} &
                    """
                )  # FIXME Handle this port forward here with the kubernetes library
                return f"tcp://localhost:{port}"
            return f"tcp://{service.spec.cluster_ip}:{port}"
        raise RuntimeError("Unable to determine scheduler address.")

    async def _wait_for_workers(self):
        """Poll until the scheduler reports as many workers as the deployment requests."""
        while True:
            n_workers = len(self.scheduler_info["workers"])
            deployment = await self.apps_api.read_namespaced_deployment(
                name=f"{self.release_name}-{self.worker_name}",
                namespace=self.namespace,
            )
            if n_workers == deployment.spec.replicas:
                return
            await asyncio.sleep(0.2)

    def get_logs(self):
        """Get logs for Dask scheduler and workers.

        Examples
        --------
        >>> cluster.get_logs()
        {'testdask-scheduler-5c8ffb6b7b-sjgrg': ...,
        'testdask-worker-64c8b78cc-992z8': ...,
        'testdask-worker-64c8b78cc-hzpdc': ...,
        'testdask-worker-64c8b78cc-wbk4f': ...}

        Each log will be a string of all logs for that container. To view
        it is recommended that you print each log.

        >>> print(cluster.get_logs()["testdask-scheduler-5c8ffb6b7b-sjgrg"])
        ...
        distributed.scheduler - INFO - -----------------------------------------------
        distributed.scheduler - INFO - Clear task state
        distributed.scheduler - INFO -   Scheduler at:     tcp://10.1.6.131:8786
        distributed.scheduler - INFO -   dashboard at:                     :8787
        ...
        """
        return self.sync(self._get_logs)

    async def _get_logs(self):
        """Collect the log of every scheduler/worker pod of this release, keyed by pod name."""
        logs = Logs()
        pods = await self.core_api.list_namespaced_pod(
            namespace=self.namespace,
            label_selector=f"release={self.release_name},app=dask",
        )
        for pod in pods.items:
            pod_name = pod.metadata.name
            if "scheduler" in pod_name or "worker" in pod_name:
                logs[pod_name] = Log(
                    await self.core_api.read_namespaced_pod_log(
                        pod_name, pod.metadata.namespace
                    )
                )
        return logs

    def __await__(self):
        """Start the cluster if it is only created, or wait for workers if it is running."""

        async def _():
            if self.status == "created":
                await self._start()
            elif self.status == "running":
                await self._wait_for_workers()
            return self

        return _().__await__()

    def scale(self, n_workers):
        """Scale cluster to n workers.

        This sets the Dask worker deployment size to the requested number.
        Workers will not be terminated gracefully so be sure to only scale down
        when all futures have been retrieved by the client and the cluster is idle.

        Examples
        --------
        >>> cluster
        HelmCluster('tcp://localhost:8786', workers=3, threads=18, memory=18.72 GB)
        >>> cluster.scale(4)
        >>> cluster
        HelmCluster('tcp://localhost:8786', workers=4, threads=24, memory=24.96 GB)
        """
        return self.sync(self._scale, n_workers)

    async def _scale(self, n_workers):
        """Patch the worker deployment's replica count to ``n_workers``."""
        await self.apps_api.patch_namespaced_deployment(
            name=f"{self.release_name}-{self.worker_name}",
            namespace=self.namespace,
            body={
                "spec": {
                    "replicas": n_workers,
                }
            },
        )

    def adapt(self, *args, **kwargs):
        """Turn on adaptivity (Not recommended).

        Raises
        ------
        NotImplementedError
            Always.
        """
        # Each fragment ends with a space: adjacent literals are joined
        # verbatim, so without it the sentences run together.
        raise NotImplementedError(
            "It is not recommended to run ``HelmCluster`` in adaptive mode. "
            "When scaling down workers the decision on which worker to remove is left to Kubernetes, which "
            "will not necessarily remove the same worker that Dask would choose. This may result in lost futures and "
            "recalculation. It is recommended to manage scaling yourself with the ``HelmCluster.scale`` method."
        )

    async def _adapt(self, *args, **kwargs):
        """Delegate to the parent class's ``adapt``, bypassing the public override."""
        return super().adapt(*args, **kwargs)
class HelmCluster(Cluster):
    """Connect to a Dask cluster deployed via the Helm Chart.

    This cluster manager connects to an existing Dask deployment that was
    created by the Dask Helm Chart, enabling you to perform basic cluster
    actions such as scaling and log retrieval.

    Parameters
    ----------
    release_name: str
        Name of the helm release to connect to.
    namespace: str (optional)
        Namespace in which to launch the workers.
        Defaults to current namespace if available or "default"
    port_forward_cluster_ip: bool (optional)
        If the chart uses ClusterIP type services, forward the ports locally.
        If you are using ``HelmCluster`` from the Jupyter session that was
        installed by the helm chart this should be ``False``. If you are
        running it locally it should be the port you are forwarding to
        ``<port>``.
    auth: List[ClusterAuth] (optional)
        Configuration methods to attempt in order. Defaults to
        ``[InCluster(), KubeConfig()]``.
    scheduler_name: str (optional)
        Name of the Dask scheduler deployment in the current release.
        Defaults to "scheduler".
    worker_name: str (optional)
        Name of the Dask worker deployment in the current release.
        Defaults to "worker".
    node_host: str (optional)
        A node address. Can be provided in case scheduler service type is
        ``NodePort`` and you want to manually specify which node to connect to.
    node_port: int (optional)
        A node port. Can be provided in case scheduler service type is
        ``NodePort`` and you want to manually specify which port to connect to.
    **kwargs: dict
        Additional keyword arguments to pass to Cluster.

    Examples
    --------
    >>> from dask_kubernetes import HelmCluster
    >>> cluster = HelmCluster(release_name="myhelmrelease")

    You can then resize the cluster with the scale method

    >>> cluster.scale(10)

    You can pass this cluster directly to a Dask client

    >>> from dask.distributed import Client
    >>> client = Client(cluster)

    You can also access cluster logs

    >>> cluster.get_logs()

    See Also
    --------
    HelmCluster.scale
    HelmCluster.get_logs
    """

    def __init__(
        self,
        release_name=None,
        auth=ClusterAuth.DEFAULT,
        namespace=None,
        port_forward_cluster_ip=False,
        loop=None,
        asynchronous=False,
        scheduler_name="scheduler",
        worker_name="worker",
        node_host=None,
        node_port=None,
        **kwargs,
    ):
        """Validate the helm release and, in synchronous mode, connect to it.

        Raises
        ------
        RuntimeError
            If ``helm status`` exits with a non-zero code for ``release_name``.
        """
        self.release_name = release_name
        self.namespace = namespace or namespace_default()
        check_dependency("helm")
        check_dependency("kubectl")
        # Fail early, before touching the Kubernetes API, when helm does not
        # know the release in this namespace.
        status = subprocess.run(
            ["helm", "-n", self.namespace, "status", self.release_name],
            capture_output=True,
            encoding="utf-8",
        )
        if status.returncode != 0:
            raise RuntimeError(f"No such helm release {self.release_name}.")
        self.auth = auth
        # The API client and scheduler comm are created in ``_start``.
        self.core_api = None
        self.scheduler_comm = None
        self.port_forward_cluster_ip = port_forward_cluster_ip
        self._supports_scaling = True
        self._loop_runner = LoopRunner(loop=loop, asynchronous=asynchronous)
        self.loop = self._loop_runner.loop
        self.scheduler_name = scheduler_name
        self.worker_name = worker_name
        self.node_host = node_host
        self.node_port = node_port
        super().__init__(asynchronous=asynchronous, **kwargs)
        if not self.asynchronous:
            self._loop_runner.start()
            self.sync(self._start)

    async def _start(self):
        """Load credentials, create the API clients and open the scheduler comm."""
        await ClusterAuth.load_first(self.auth)
        self.core_api = kubernetes.client.CoreV1Api()
        self.apps_api = kubernetes.client.AppsV1Api()
        self.scheduler_comm = rpc(await self._get_scheduler_address())
        await super()._start()

    async def _get_scheduler_address(self):
        """Resolve the scheduler service of the release and return its address.

        Also sets ``self.chart_name``, the name infix that ``_scale`` and
        ``_wait_for_workers`` use to build the worker deployment name.

        Raises
        ------
        RuntimeError
            If no external address can be determined for the service.
        """
        # Get the chart name
        listing = subprocess.check_output(
            [
                "helm",
                "-n",
                self.namespace,
                "list",
                "-f",
                self.release_name,
                "--output",
                "json",
            ],
            encoding="utf-8",
        )
        releases = json.loads(listing)
        # ``helm list -f`` treats its argument as a regular expression, so
        # releases with longer names can match as well. Prefer the entry whose
        # name is exactly ours and fall back to the first one otherwise.
        release = next(
            (entry for entry in releases if entry.get("name") == self.release_name),
            releases[0],
        )
        # extract name from {{.Chart.Name }}-{{ .Chart.Version }}
        chart_name = release["chart"].rpartition("-")[0]
        # Follow the spec in the dask/dask helm chart
        if chart_name in self.release_name:
            self.chart_name = ""
        else:
            self.chart_name = f"{chart_name}-"

        service_name = f"{self.release_name}-{self.chart_name}{self.scheduler_name}"
        service = await self.core_api.read_namespaced_service(
            service_name, self.namespace
        )
        address = await get_external_address_for_scheduler_service(
            self.core_api,
            service,
            port_forward_cluster_ip=self.port_forward_cluster_ip,
        )
        if address is None:
            raise RuntimeError("Unable to determine scheduler address.")
        return address

    async def _wait_for_workers(self):
        """Poll until the scheduler reports as many workers as the deployment requests."""
        while True:
            n_workers = len(self.scheduler_info["workers"])
            deployment = await self.apps_api.read_namespaced_deployment(
                name=f"{self.release_name}-{self.chart_name}{self.worker_name}",
                namespace=self.namespace,
            )
            if n_workers == deployment.spec.replicas:
                return
            await asyncio.sleep(0.2)

    def get_logs(self):
        """Get logs for Dask scheduler and workers.

        Examples
        --------
        >>> cluster.get_logs()
        {'testdask-scheduler-5c8ffb6b7b-sjgrg': ...,
        'testdask-worker-64c8b78cc-992z8': ...,
        'testdask-worker-64c8b78cc-hzpdc': ...,
        'testdask-worker-64c8b78cc-wbk4f': ...}

        Each log will be a string of all logs for that container. To view
        it is recommended that you print each log.

        >>> print(cluster.get_logs()["testdask-scheduler-5c8ffb6b7b-sjgrg"])
        ...
        distributed.scheduler - INFO - -----------------------------------------------
        distributed.scheduler - INFO - Clear task state
        distributed.scheduler - INFO -   Scheduler at:     tcp://10.1.6.131:8786
        distributed.scheduler - INFO -   dashboard at:                     :8787
        ...
        """
        return self.sync(self._get_logs)

    async def _get_logs(self):
        """Collect the log of every scheduler/worker pod of this release, keyed by pod name.

        Pods that are not in the ``Running`` phase, or whose log request
        fails, get a placeholder entry instead of raising.
        """
        logs = Logs()
        pods = await self.core_api.list_namespaced_pod(
            namespace=self.namespace,
            label_selector=f"release={self.release_name},app=dask",
        )
        for pod in pods.items:
            pod_name = pod.metadata.name
            if "scheduler" not in pod_name and "worker" not in pod_name:
                continue
            if pod.status.phase != "Running":
                log = Log(f"Cannot find logs. Pod is {pod.status.phase}.")
            else:
                try:
                    log = Log(
                        await self.core_api.read_namespaced_pod_log(
                            pod_name, pod.metadata.namespace
                        )
                    )
                except (ValueError, kubernetes.client.exceptions.ApiException):
                    log = Log(f"Cannot find logs. Pod is {pod.status.phase}.")
            logs[pod_name] = log
        return logs

    def __await__(self):
        """Start the cluster if it is only created, or wait for workers if it is running."""

        async def _():
            if self.status == Status.created:
                await self._start()
            elif self.status == Status.running:
                await self._wait_for_workers()
            return self

        return _().__await__()

    def scale(self, n_workers):
        """Scale cluster to n workers.

        This sets the Dask worker deployment size to the requested number.
        Workers will not be terminated gracefully so be sure to only scale down
        when all futures have been retrieved by the client and the cluster is idle.

        Examples
        --------
        >>> cluster
        HelmCluster('tcp://localhost:8786', workers=3, threads=18, memory=18.72 GB)
        >>> cluster.scale(4)
        >>> cluster
        HelmCluster('tcp://localhost:8786', workers=4, threads=24, memory=24.96 GB)
        """
        return self.sync(self._scale, n_workers)

    async def _scale(self, n_workers):
        """Patch the worker deployment's replica count to ``n_workers``."""
        await self.apps_api.patch_namespaced_deployment(
            name=f"{self.release_name}-{self.chart_name}{self.worker_name}",
            namespace=self.namespace,
            body={
                "spec": {
                    "replicas": n_workers,
                }
            },
        )

    def adapt(self, *args, **kwargs):
        """Turn on adaptivity (Not recommended).

        Raises
        ------
        NotImplementedError
            Always.
        """
        # Each fragment ends with a space: adjacent literals are joined
        # verbatim, so without it the sentences run together.
        raise NotImplementedError(
            "It is not recommended to run ``HelmCluster`` in adaptive mode. "
            "When scaling down workers the decision on which worker to remove is left to Kubernetes, which "
            "will not necessarily remove the same worker that Dask would choose. This may result in lost futures and "
            "recalculation. It is recommended to manage scaling yourself with the ``HelmCluster.scale`` method."
        )

    async def _adapt(self, *args, **kwargs):
        """Delegate to the parent class's ``adapt``, bypassing the public override."""
        return super().adapt(*args, **kwargs)