def test_close_loop_sync(with_own_loop):
    """Closing a synchronous LocalCluster/Client must not leak loop runners.

    Runs a tiny threaded cluster either on a caller-owned ``LoopRunner`` loop
    (``with_own_loop=True``) or on the loop chosen by default (``loop=None``),
    and checks that ``LoopRunner._all_loops`` is unchanged afterwards.
    """
    loop_runner = loop = None

    # Setup simple cluster with one threaded worker.
    # Complex setup is not required here since we test only IO loop teardown.
    cluster_params = dict(n_workers=1, dashboard_address=None, processes=False)

    loops_before = LoopRunner._all_loops.copy()

    # Start own loop or use current thread's one.
    if with_own_loop:
        loop_runner = LoopRunner()
        loop_runner.start()
        loop = loop_runner.loop

    try:
        with LocalCluster(loop=loop, **cluster_params) as cluster:
            with Client(cluster, loop=loop) as client:
                client.run(max, 1, 2)
    finally:
        # Our own loop must be explicitly stopped. Doing it in ``finally``
        # means a failure above does not leave that loop thread running.
        if loop_runner is not None:
            loop_runner.stop()

    # Internal loops registry must be the same as before the cluster ran.
    # This means the loop runners in LocalCluster and Client were correctly
    # stopped. See LoopRunner._stop_unlocked().
    assert loops_before == LoopRunner._all_loops
async def test_loop_runner_gen():
    """An asynchronous LoopRunner adopts the current IOLoop.

    ``start()`` and ``stop()`` must flip ``is_started()``; the loop is given a
    short chance to run between each step.
    """
    lr = LoopRunner(asynchronous=True)
    assert lr.loop is IOLoop.current()
    assert not lr.is_started()
    await asyncio.sleep(0.01)

    # Each transition is followed by a check of the started flag and a yield
    # back to the event loop.
    for transition, expect_started in ((lr.start, True), (lr.stop, False)):
        transition()
        assert bool(lr.is_started()) == expect_started
        await asyncio.sleep(0.01)
def test_loop_runner_gen():
    """Generator-coroutine variant of the asynchronous LoopRunner check.

    The runner adopts the current IOLoop, and ``start()``/``stop()`` flip
    ``is_started()``; each step yields ``gen.sleep`` to the driving loop.
    """
    # NOTE(review): a coroutine test with this exact name is also defined in
    # this SOURCE; within a single module the later definition shadows the
    # earlier one, so only one of them would be collected.
    lr = LoopRunner(asynchronous=True)
    assert lr.loop is IOLoop.current()
    assert not lr.is_started()
    yield gen.sleep(0.01)

    for transition, expect_started in ((lr.start, True), (lr.stop, False)):
        transition()
        assert bool(lr.is_started()) == expect_started
        yield gen.sleep(0.01)
class ContextCluster(Cluster):
    """Minimal ``Cluster`` subclass that owns a ``LoopRunner``.

    In synchronous mode the runner is started and ``_start`` is driven to
    completion during construction. Used as a context manager, it insists on
    being in the ``"running"`` state on entry and tears down both the
    cluster and its loop runner on exit.
    """

    def __init__(self, asynchronous=False, loop=None):
        runner = LoopRunner(loop=loop, asynchronous=asynchronous)
        self._loop_runner = runner
        self.loop = runner.loop

        super().__init__(asynchronous=asynchronous)

        if not self.asynchronous:
            runner.start()
            self.sync(self._start)

    def __enter__(self):
        # Guard: only a fully started cluster may be used in a ``with`` block.
        if self.status == "running":
            return self
        raise ValueError(
            f"Expected status 'running', found '{self.status}'")

    def __exit__(self, typ, value, traceback):
        # Close the cluster first, then release the loop it ran on.
        self.close()
        self._loop_runner.stop()
def test_two_loop_runners(loop_in_thread):
    """Loop runners tied to the same loop should cooperate.

    The shared loop must keep running until the last runner that started it
    is stopped, whatever the stop order. A loop that was already running
    before any runner touched it must stay running throughout.
    """
    # Nested order: start A, B, C then stop C, B, A (ABCCBA).
    io_loop = IOLoop()
    first = LoopRunner(loop=io_loop)
    second = LoopRunner(loop=io_loop)
    assert_not_running(io_loop)
    first.start()
    assert_running(io_loop)
    third = LoopRunner(loop=io_loop)
    for runner in (second, third):
        runner.start()
        assert_running(io_loop)
    for runner in (third, second):
        runner.stop()
        assert_running(io_loop)
    first.stop()
    assert_not_running(io_loop)

    # Interleaved order: start A, B, C then stop A, B, C (ABCABC).
    io_loop = IOLoop()
    first = LoopRunner(loop=io_loop)
    second = LoopRunner(loop=io_loop)
    assert_not_running(io_loop)
    for runner in (first, second):
        runner.start()
        assert_running(io_loop)
    third = LoopRunner(loop=io_loop)
    third.start()
    assert_running(io_loop)
    for runner in (first, second):
        runner.stop()
        assert_running(io_loop)
    third.stop()
    assert_not_running(io_loop)

    # Explicit loop, already started: no start/stop may halt it.
    first = LoopRunner(loop=loop_in_thread)
    second = LoopRunner(loop=loop_in_thread)
    assert_running(loop_in_thread)
    for action in (first.start, second.start, first.stop, second.stop):
        action()
        assert_running(loop_in_thread)
def test_loop_runner(loop_in_thread):
    """Exercise LoopRunner start/stop for every way a loop can be supplied.

    Covers an implicit loop, an explicit loop, an explicit loop already
    running in another thread, and the ``asynchronous=True`` variants of the
    implicit and explicit cases.
    """

    def cycle(runner, target, before, during, after):
        # Take ``runner`` through start() and stop(), checking its started
        # flag and the running state of ``target`` at each stage.
        assert not runner.is_started()
        before(target)
        runner.start()
        assert runner.is_started()
        during(target)
        runner.stop()
        assert not runner.is_started()
        after(target)

    # Implicit loop
    io_loop = IOLoop()
    io_loop.make_current()
    runner = LoopRunner()
    assert runner.loop not in (io_loop, loop_in_thread)
    cycle(runner, runner.loop,
          assert_not_running, assert_running, assert_not_running)

    # Explicit loop
    io_loop = IOLoop()
    runner = LoopRunner(loop=io_loop)
    assert runner.loop is io_loop
    cycle(runner, io_loop,
          assert_not_running, assert_running, assert_not_running)

    # Explicit loop, already started
    runner = LoopRunner(loop=loop_in_thread)
    cycle(runner, loop_in_thread,
          assert_running, assert_running, assert_running)

    # Implicit loop, asynchronous=True
    io_loop = IOLoop()
    io_loop.make_current()
    runner = LoopRunner(asynchronous=True)
    assert runner.loop is io_loop
    cycle(runner, runner.loop,
          assert_not_running, assert_not_running, assert_not_running)

    # Explicit loop, asynchronous=True
    io_loop = IOLoop()
    runner = LoopRunner(loop=io_loop, asynchronous=True)
    assert runner.loop is io_loop
    cycle(runner, runner.loop,
          assert_not_running, assert_not_running, assert_not_running)
class YarnCluster(object):
    """Start a Dask cluster on YARN.

    You can define default values for this in Dask's ``yarn.yaml``
    configuration file. See http://docs.dask.org/en/latest/configuration.html
    for more information.

    Parameters
    ----------
    environment : str, optional
        The Python environment to use. Can be one of the following:

          - A path to an archived Python environment
          - A path to a conda environment, specified as `conda:///...`
          - A path to a virtual environment, specified as `venv:///...`
          - A path to a python executable, specified as `python:///...`

        Note that if not an archive, the paths specified must be valid on all
        nodes in the cluster.
    n_workers : int, optional
        The number of workers to initially start.
    worker_vcores : int, optional
        The number of virtual cores to allocate per worker.
    worker_memory : str, optional
        The amount of memory to allocate per worker. Accepts a unit suffix
        (e.g. '2 GiB' or '4096 MiB'). Will be rounded up to the nearest MiB.
    worker_restarts : int, optional
        The maximum number of worker restarts to allow before failing the
        application. Default is unlimited.
    worker_env : dict, optional
        A mapping of environment variables to their values. These will be set
        in the worker containers before starting the dask workers.
    scheduler_vcores : int, optional
        The number of virtual cores to allocate per scheduler.
    scheduler_memory : str, optional
        The amount of memory to allocate to the scheduler. Accepts a unit
        suffix (e.g. '2 GiB' or '4096 MiB'). Will be rounded up to the nearest
        MiB.
    deploy_mode : {'remote', 'local'}, optional
        The deploy mode to use. If ``'remote'``, the scheduler will be
        deployed in a YARN container. If ``'local'``, the scheduler will run
        locally, which can be nice for debugging. Default is ``'remote'``.
    name : str, optional
        The application name.
    queue : str, optional
        The queue to deploy to.
    tags : sequence, optional
        A set of strings to use as tags for this application.
    user : str, optional
        The user to submit the application on behalf of. Default is the
        current user - submitting as a different user requires user
        permissions, see the YARN documentation for more information.
    host : str, optional
        Host address on which the scheduler will listen. Only used if
        ``deploy_mode='local'``. Defaults to ``'0.0.0.0'``.
    port : int, optional
        The port on which the scheduler will listen. Only used if
        ``deploy_mode='local'``. Defaults to ``0`` for a random port.
    dashboard_address : str, optional
        Address on which the dashboard server will listen. Only used if
        ``deploy_mode='local'``. Defaults to ':0' for a random port.
    skein_client : skein.Client, optional
        The ``skein.Client`` to use. If not provided, one will be started.
    asynchronous : bool, optional
        If true, starts the cluster in asynchronous mode, where it can be used
        in other async code.
    loop : IOLoop, optional
        The IOLoop instance to use. Defaults to the current loop in
        asynchronous mode, otherwise a background loop is started.

    Examples
    --------
    >>> cluster = YarnCluster(environment='my-env.tar.gz', ...)
    >>> cluster.scale(10)
    """

    def __init__(
        self,
        environment=None,
        n_workers=None,
        worker_vcores=None,
        worker_memory=None,
        worker_restarts=None,
        worker_env=None,
        scheduler_vcores=None,
        scheduler_memory=None,
        deploy_mode=None,
        name=None,
        queue=None,
        tags=None,
        user=None,
        host=None,
        port=None,
        dashboard_address=None,
        skein_client=None,
        asynchronous=False,
        loop=None,
    ):
        spec = _make_specification(
            environment=environment,
            n_workers=n_workers,
            worker_vcores=worker_vcores,
            worker_memory=worker_memory,
            worker_restarts=worker_restarts,
            worker_env=worker_env,
            scheduler_vcores=scheduler_vcores,
            scheduler_memory=scheduler_memory,
            deploy_mode=deploy_mode,
            name=name,
            queue=queue,
            tags=tags,
            user=user,
        )
        self._init_common(
            spec=spec,
            host=host,
            port=port,
            dashboard_address=dashboard_address,
            asynchronous=asynchronous,
            loop=loop,
            skein_client=skein_client,
        )

    @classmethod
    def from_specification(cls, spec, skein_client=None, asynchronous=False,
                           loop=None):
        """Start a dask cluster from a skein specification.

        Parameters
        ----------
        spec : skein.ApplicationSpec, dict, or filename
            The application specification to use. Must define at least one
            service: ``'dask.worker'``. If no ``'dask.scheduler'`` service is
            defined, a scheduler will be started locally.
        skein_client : skein.Client, optional
            The ``skein.Client`` to use. If not provided, one will be started.
        asynchronous : bool, optional
            If true, starts the cluster in asynchronous mode, where it can be
            used in other async code.
        loop : IOLoop, optional
            The IOLoop instance to use. Defaults to the current loop in
            asynchronous mode, otherwise a background loop is started.
        """
        self = super(YarnCluster, cls).__new__(cls)
        if isinstance(spec, dict):
            spec = skein.ApplicationSpec.from_dict(spec)
        elif isinstance(spec, str):
            spec = skein.ApplicationSpec.from_file(spec)
        elif not isinstance(spec, skein.ApplicationSpec):
            raise TypeError("spec must be an ApplicationSpec, dict, or path, "
                            "got %r" % type(spec).__name__)

        if "dask.worker" not in spec.services:
            raise ValueError(
                "Provided Skein specification must include a 'dask.worker' service"
            )

        self._init_common(spec=spec, asynchronous=asynchronous, loop=loop,
                          skein_client=skein_client)
        return self

    @classmethod
    def from_current(cls, asynchronous=False, loop=None):
        """Connect to an existing ``YarnCluster`` from inside the cluster.

        Parameters
        ----------
        asynchronous : bool, optional
            If true, starts the cluster in asynchronous mode, where it can be
            used in other async code.
        loop : IOLoop, optional
            The IOLoop instance to use. Defaults to the current loop in
            asynchronous mode, otherwise a background loop is started.

        Returns
        -------
        YarnCluster
        """
        self = super(YarnCluster, cls).__new__(cls)
        app_id = os.environ.get("DASK_APPLICATION_ID", None)
        app_address = os.environ.get("DASK_APPMASTER_ADDRESS", None)
        security_dir = os.environ.get("DASK_SECURITY_CREDENTIALS", None)
        if app_id is not None and app_address is not None:
            # Explicit connection info provided through the environment.
            security = (None if security_dir is None
                        else skein.Security.from_directory(security_dir))
            app = skein.ApplicationClient(app_address, app_id,
                                          security=security)
        else:
            app = skein.ApplicationClient.from_current()

        self._init_common(application_client=app, asynchronous=asynchronous,
                          loop=loop)
        return self

    @classmethod
    def from_application_id(cls, app_id, skein_client=None, asynchronous=False,
                            loop=None):
        """Connect to an existing ``YarnCluster`` with a given application id.

        Parameters
        ----------
        app_id : str
            The existing cluster's application id.
        skein_client : skein.Client
            The ``skein.Client`` to use. If not provided, one will be started.
        asynchronous : bool, optional
            If true, starts the cluster in asynchronous mode, where it can be
            used in other async code.
        loop : IOLoop, optional
            The IOLoop instance to use. Defaults to the current loop in
            asynchronous mode, otherwise a background loop is started.

        Returns
        -------
        YarnCluster
        """
        self = super(YarnCluster, cls).__new__(cls)
        skein_client = _get_skein_client(skein_client)
        app = skein_client.connect(app_id)

        self._init_common(
            application_client=app,
            asynchronous=asynchronous,
            loop=loop,
            skein_client=skein_client,
        )
        return self

    def _init_common(
        self,
        spec=None,
        application_client=None,
        host=None,
        port=None,
        dashboard_address=None,
        asynchronous=False,
        loop=None,
        skein_client=None,
    ):
        """Shared initialisation for ``__init__`` and the ``from_*`` builders.

        Either ``spec`` (start a new application) or ``application_client``
        (connect to an existing one) is expected to be set. In synchronous
        mode the cluster is started before this returns.
        """
        self.spec = spec
        self.application_client = application_client
        self._scheduler_kwargs = _make_scheduler_kwargs(
            host=host,
            port=port,
            dashboard_address=dashboard_address,
        )
        self._scheduler = None
        self.scheduler_info = {}
        self._requested = set()
        self.scheduler_comm = None
        self._watch_worker_status_task = None
        self._start_task = None
        self._stop_task = None
        self._finalizer = None
        self._adaptive = None
        self._adaptive_options = {}
        self._skein_client = skein_client
        self._asynchronous = asynchronous
        self._loop_runner = LoopRunner(loop=loop, asynchronous=asynchronous)
        self._loop_runner.start()

        if not self.asynchronous:
            self._sync(self._start_internal())

    def _start_cluster(self):
        """Start the cluster and initialize state"""

        skein_client = _get_skein_client(self._skein_client)

        if "dask.scheduler" not in self.spec.services:
            # deploy_mode == 'local'
            scheduler_address = self._scheduler.address
            for k in ["dashboard", "bokeh"]:
                if k in self._scheduler.services:
                    dashboard_port = self._scheduler.services[k].port
                    dashboard_host = urlparse(scheduler_address).hostname
                    dashboard_address = "http://%s:%d" % (
                        dashboard_host,
                        dashboard_port,
                    )
                    break
            else:
                # No dashboard service found on the local scheduler.
                dashboard_address = None

            with submit_and_handle_failures(skein_client, self.spec) as app:
                app.kv["dask.scheduler"] = scheduler_address.encode()
                if dashboard_address is not None:
                    app.kv["dask.dashboard"] = dashboard_address.encode()
        else:
            # deploy_mode == 'remote'
            with submit_and_handle_failures(skein_client, self.spec) as app:
                scheduler_address = app.kv.wait("dask.scheduler").decode()
                dashboard_address = app.kv.get("dask.dashboard")
                if dashboard_address is not None:
                    dashboard_address = dashboard_address.decode()

        # Ensure application gets cleaned up
        self._finalizer = weakref.finalize(self, app.shutdown)

        self.scheduler_address = scheduler_address
        self._dashboard_address = dashboard_address
        self.application_client = app

    def _connect_existing(self):
        """Read scheduler/dashboard addresses from an existing application."""
        spec = self.application_client.get_specification()
        if "dask.worker" not in spec.services:
            raise ValueError("%r is not a valid dask cluster" % self.app_id)

        scheduler_address = self.application_client.kv.wait(
            "dask.scheduler").decode()
        dashboard_address = self.application_client.kv.get("dask.dashboard")
        if dashboard_address is not None:
            dashboard_address = dashboard_address.decode()

        self.spec = spec
        self.scheduler_address = scheduler_address
        self._dashboard_address = dashboard_address

    async def _start_internal(self):
        """Start the cluster once; on any failure, stop it and re-raise."""
        if self._start_task is None:
            self._start_task = asyncio.ensure_future(self._start_async())
        try:
            await self._start_task
        except BaseException:
            # On exception, cleanup
            await self._stop_internal()
            raise
        return self

    async def _start_async(self):
        if self.spec is not None:
            # Start a new cluster
            if "dask.scheduler" not in self.spec.services:
                self._scheduler = Scheduler(
                    loop=self.loop,
                    **self._scheduler_kwargs,
                )
                await self._scheduler
            else:
                self._scheduler = None
            await self.loop.run_in_executor(None, self._start_cluster)
        else:
            # Connect to an existing cluster
            await self.loop.run_in_executor(None, self._connect_existing)

        self.scheduler_comm = rpc(self.scheduler_address)
        comm = None
        try:
            comm = await self.scheduler_comm.live_comm()
            await comm.write({"op": "subscribe_worker_status"})
            self.scheduler_info = await comm.read()
            workers = self.scheduler_info.get("workers", {})
            self._requested.update(w["name"] for w in workers.values())
            self._watch_worker_status_task = asyncio.ensure_future(
                self._watch_worker_status(comm))
        except Exception:
            # NOTE(review): the error is dropped after closing the comm, so
            # the cluster still starts without a worker-status watcher and
            # ``scheduler_info`` may remain ``{}``; consider logging it.
            if comm is not None:
                await comm.close()

    async def _stop_internal(self, status="SUCCEEDED", diagnostics=None):
        """Stop the cluster once; concurrent callers await the same task."""
        if self._stop_task is None:
            self._stop_task = asyncio.ensure_future(
                self._stop_async(status=status, diagnostics=diagnostics))
        await self._stop_task

    async def _stop_async(self, status="SUCCEEDED", diagnostics=None):
        if self._start_task is not None:
            if not self._start_task.done():
                # We're still starting, cancel task
                await cancel_task(self._start_task)
            self._start_task = None

        if self._adaptive is not None:
            self._adaptive.stop()

        if self._watch_worker_status_task is not None:
            await cancel_task(self._watch_worker_status_task)
            self._watch_worker_status_task = None

        if self.scheduler_comm is not None:
            self.scheduler_comm.close_rpc()
            self.scheduler_comm = None

        # The YARN shutdown call is blocking, so run it off the loop.
        await self.loop.run_in_executor(
            None,
            lambda: self._stop_sync(status=status, diagnostics=diagnostics))

        if self._scheduler is not None:
            await self._scheduler.close()
            self._scheduler = None

    def _stop_sync(self, status="SUCCEEDED", diagnostics=None):
        """Shut the YARN application down if this object still owns it."""
        if self._finalizer is not None and self._finalizer.peek() is not None:
            self.application_client.shutdown(status=status,
                                             diagnostics=diagnostics)
            self._finalizer.detach()  # don't run the finalizer later
        self._finalizer = None

    def __await__(self):
        return self.__aenter__().__await__()

    async def __aenter__(self):
        return await self._start_internal()

    async def __aexit__(self, typ, value, traceback):
        await self._stop_internal()

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

    def __repr__(self):
        return "YarnCluster<%s>" % self.app_id

    @property
    def loop(self):
        return self._loop_runner.loop

    @property
    def app_id(self):
        return self.application_client.id

    @property
    def asynchronous(self):
        return self._asynchronous

    def _sync(self, task):
        """Run coroutine ``task`` on the cluster loop, or return it if async."""
        if self.asynchronous:
            return task
        future = asyncio.run_coroutine_threadsafe(task,
                                                  self.loop.asyncio_loop)
        try:
            return future.result()
        except BaseException:
            future.cancel()
            raise

    @cached_property
    def dashboard_link(self):
        """Link to the dask dashboard. None if dashboard isn't running"""
        if self._dashboard_address is None:
            return None
        dashboard = urlparse(self._dashboard_address)
        return format_dashboard_link(dashboard.hostname, dashboard.port)

    def shutdown(self, status="SUCCEEDED", diagnostics=None):
        """Shutdown the application.

        Parameters
        ----------
        status : {'SUCCEEDED', 'FAILED', 'KILLED'}, optional
            The yarn application exit status.
        diagnostics : str, optional
            The application exit message, usually used for diagnosing
            failures. Can be seen in the YARN Web UI for completed
            applications under "diagnostics". If not provided, a default will
            be used.
        """
        if self.asynchronous:
            return self._stop_internal(status=status, diagnostics=diagnostics)
        if self.loop.asyncio_loop.is_running() and not sys.is_finalizing():
            self._sync(
                self._stop_internal(status=status, diagnostics=diagnostics))
        else:
            # Always run this!
            self._stop_sync(status=status, diagnostics=diagnostics)
        self._loop_runner.stop()

    def close(self, **kwargs):
        """Close this cluster. An alias for ``shutdown``.

        See Also
        --------
        shutdown
        """
        return self.shutdown(**kwargs)

    def __del__(self):
        # Construction may have failed before the loop runner existed.
        if not hasattr(self, "_loop_runner"):
            return
        if self.asynchronous:
            # No del for async mode
            return
        self.close()

    @property
    def _observed(self):
        """Names of the workers the scheduler currently reports."""
        # ``scheduler_info`` starts as ``{}`` and stays that way if the
        # worker-status subscription in ``_start_async`` failed, so the
        # "workers" key may be missing.
        workers = self.scheduler_info.get("workers", {})
        return {w["name"] for w in workers.values()}

    async def _workers(self):
        return await self.loop.run_in_executor(
            None,
            lambda: self.application_client.get_containers(
                services=["dask.worker"]),
        )

    def workers(self):
        """A list of all currently running worker containers."""
        return self._sync(self._workers())

    async def _logs(self, scheduler=True, workers=True):
        logs = Logs()

        if scheduler:
            slogs = await self.scheduler_comm.logs()
            logs["Scheduler"] = Log("\n".join(line for level, line in slogs))

        if workers:
            d = await self.scheduler_comm.worker_logs(workers=workers)
            for k, v in d.items():
                logs[k] = Log("\n".join(line for level, line in v))

        return logs

    def logs(self, scheduler=True, workers=True):
        """Return logs for the scheduler and/or workers

        Parameters
        ----------
        scheduler : boolean, optional
            Whether or not to collect logs for the scheduler
        workers : boolean or iterable, optional
            A list of worker addresses to select.
            Defaults to all workers if ``True`` or no workers if ``False``

        Returns
        -------
        logs : dict
            A dictionary of name -> logs.
        """
        return self._sync(self._logs(scheduler=scheduler, workers=workers))

    async def _scale_up(self, n):
        """Ask YARN for ``n`` worker containers if more are needed."""
        if n > len(self._requested):
            containers = await self.loop.run_in_executor(
                None,
                lambda: self.application_client.scale("dask.worker", n),
            )
            self._requested.update(c.id for c in containers)

    async def _scale_down(self, workers):
        """Retire the named workers and kill their containers."""
        self._requested.difference_update(workers)
        await self.scheduler_comm.retire_workers(names=list(workers))

        def _kill_containers():
            for c in workers:
                try:
                    self.application_client.kill_container(c)
                except ValueError:
                    pass

        await self.loop.run_in_executor(
            None,
            _kill_containers,
        )

    async def _scale(self, n):
        # Manual scaling overrides adaptive scaling.
        if self._adaptive is not None:
            self._adaptive.stop()
        if n >= len(self._requested):
            return await self._scale_up(n)
        else:
            n_to_delete = len(self._requested) - n
            # Prefer closing requested-but-not-yet-running workers first.
            pending = list(self._requested - self._observed)
            running = list(self._observed.difference(pending))
            to_close = pending[:n_to_delete]
            to_close.extend(running[:n_to_delete - len(to_close)])
            await self._scale_down(to_close)
            self._requested.difference_update(to_close)

    def scale(self, n):
        """Scale cluster to n workers.

        Parameters
        ----------
        n : int
            Target number of workers

        Examples
        --------
        >>> cluster.scale(10)  # scale cluster to ten workers
        """
        return self._sync(self._scale(n))

    def adapt(self, minimum=0, maximum=math.inf, interval="1s", wait_count=3,
              target_duration="5s", **kwargs):
        """Turn on adaptivity

        This scales Dask clusters automatically based on scheduler activity.

        Parameters
        ----------
        minimum : int, optional
            Minimum number of workers. Defaults to ``0``.
        maximum : int, optional
            Maximum number of workers. Defaults to ``inf``.
        interval : timedelta or str, optional
            Time between worker add/remove recommendations.
        wait_count : int, optional
            Number of consecutive times that a worker should be suggested for
            removal before we remove it.
        target_duration : timedelta or str, optional
            Amount of time we want a computation to take. This affects how
            aggressively we scale up.
        **kwargs :
            Additional parameters to pass to
            ``distributed.Scheduler.workers_to_close``.

        Examples
        --------
        >>> cluster.adapt(minimum=0, maximum=10)
        """
        if self._adaptive is not None:
            self._adaptive.stop()
        self._adaptive_options.update(
            minimum=minimum,
            maximum=maximum,
            interval=interval,
            wait_count=wait_count,
            target_duration=target_duration,
            **kwargs,
        )
        self._adaptive = Adaptive(self, **self._adaptive_options)

    async def _watch_worker_status(self, comm):
        """Apply worker add/remove messages from ``comm`` to local state."""
        # We don't want to hold on to a ref to self, otherwise this will
        # leave a dangling reference and prevent garbage collection.
        ref_self = weakref.ref(self)
        self = None
        try:
            while True:
                try:
                    msgs = await comm.read()
                except OSError:
                    break
                try:
                    self = ref_self()
                    if self is None:
                        break
                    for op, msg in msgs:
                        if op == "add":
                            workers = msg.pop("workers")
                            self.scheduler_info.setdefault(
                                "workers", {}).update(workers)
                            self.scheduler_info.update(msg)
                            self._requested.update(
                                w["name"] for w in workers.values())
                        elif op == "remove":
                            # Tolerate removals of unknown workers instead of
                            # letting a KeyError end this watcher task.
                            self.scheduler_info.get("workers", {}).pop(
                                msg, None)
                            self._requested.discard(msg)
                    if hasattr(self, "_status_widget"):
                        self._status_widget.value = self._widget_status()
                finally:
                    self = None
        finally:
            await comm.close()

    def _widget_status(self):
        """HTML summary of workers/cores/memory, or None if not yet known."""
        try:
            workers = self.scheduler_info["workers"]
        except KeyError:
            return None
        else:
            n_workers = len(workers)
            cores = sum(w["nthreads"] for w in workers.values())
            memory = sum(w["memory_limit"] for w in workers.values())

            return _widget_status_template % (n_workers, cores,
                                              format_bytes(memory))

    def _widget(self):
        """ Create IPython widget for display within a notebook """
        try:
            return self._cached_widget
        except AttributeError:
            pass

        if self.asynchronous:
            return None

        try:
            from ipywidgets import Layout, VBox, HBox, IntText, Button, HTML, Accordion
        except ImportError:
            self._cached_widget = None
            return None

        layout = Layout(width="150px")

        title = HTML("<h2>YarnCluster</h2>")

        status = HTML(self._widget_status(), layout=Layout(min_width="150px"))

        request = IntText(0, description="Workers", layout=layout)
        scale = Button(description="Scale", layout=layout)

        minimum = IntText(0, description="Minimum", layout=layout)
        maximum = IntText(0, description="Maximum", layout=layout)
        adapt = Button(description="Adapt", layout=layout)

        accordion = Accordion(
            [HBox([request, scale]),
             HBox([minimum, maximum, adapt])],
            layout=Layout(min_width="500px"),
        )
        accordion.selected_index = None
        accordion.set_title(0, "Manual Scaling")
        accordion.set_title(1, "Adaptive Scaling")

        @adapt.on_click
        def adapt_cb(b):
            self.adapt(minimum=minimum.value, maximum=maximum.value)

        @scale.on_click
        def scale_cb(b):
            with log_errors():
                self.scale(request.value)

        app_id = HTML("<p><b>Application ID: </b>{0}</p>".format(self.app_id))

        elements = [title, HBox([status, accordion]), app_id]

        if self.dashboard_link is not None:
            link = HTML(
                '<p><b>Dashboard: </b><a href="{0}" target="_blank">{0}'
                "</a></p>\n".format(self.dashboard_link))
            elements.append(link)

        self._cached_widget = box = VBox(elements)
        # Kept so ``_watch_worker_status`` can refresh the status text.
        self._status_widget = status

        return box

    def _ipython_display_(self, **kwargs):
        widget = self._widget()
        if widget is not None:
            return widget._ipython_display_(**kwargs)
        else:
            from IPython.display import display

            data = {"text/plain": repr(self), "text/html": self._repr_html_()}
            display(data, raw=True)

    def _repr_html_(self):
        if self.dashboard_link is not None:
            dashboard = "<a href='{0}' target='_blank'>{0}</a>".format(
                self.dashboard_link)
        else:
            dashboard = "Not Available"
        return (
            "<div style='background-color: #f2f2f2; display: inline-block; "
            "padding: 10px; border: 1px solid #999999;'>\n"
            " <h3>YarnCluster</h3>\n"
            " <ul>\n"
            " <li><b>Application ID: </b>{app_id}\n"
            " <li><b>Dashboard: </b>{dashboard}\n"
            " </ul>\n"
            "</div>\n").format(app_id=self.app_id, dashboard=dashboard)
class Gateway(object):
    """A client for a Dask Gateway Server.

    Parameters
    ----------
    address : str, optional
        The address to the gateway server.
    proxy_address : str, int, optional
        The address of the scheduler proxy server. If an int, it's used as the
        port, with the host/ip taken from ``address``. Provide a full address
        if a different host/ip should be used.
    auth : GatewayAuth, optional
        The authentication method to use.
    asynchronous : bool, optional
        If true, starts the client in asynchronous mode, where it can be used
        in other async code.
    loop : IOLoop, optional
        The IOLoop instance to use. Defaults to the current loop in
        asynchronous mode, otherwise a background loop is started.
    """

    def __init__(self, address=None, proxy_address=None, auth=None,
                 asynchronous=False, loop=None):
        if address is None:
            address = dask.config.get("gateway.address")
        if address is None:
            raise ValueError(
                "No dask-gateway address provided or found in configuration")
        address = address.rstrip("/")

        if proxy_address is None:
            proxy_address = dask.config.get("gateway.proxy-address")
        if proxy_address is None:
            raise ValueError(
                "No dask-gateway proxy address provided or found in configuration"
            )
        if isinstance(proxy_address, int):
            # Bare port: reuse the gateway's host.
            parsed = urlparse(address)
            proxy_netloc = "%s:%d" % (parsed.hostname, proxy_address)
        elif isinstance(proxy_address, str):
            # Accept either a full URL or a bare ``host:port``.
            parsed = urlparse(proxy_address)
            proxy_netloc = parsed.netloc if parsed.netloc else proxy_address
        else:
            # Previously fell through to an unbound ``proxy_netloc``.
            raise TypeError("proxy_address must be a str or int, got %r"
                            % type(proxy_address).__name__)
        proxy_address = "gateway://%s" % proxy_netloc

        self.address = address
        self.proxy_address = proxy_address

        self._auth = get_auth(auth)
        self._cookie_jar = CookieJar()

        self._asynchronous = asynchronous
        self._loop_runner = LoopRunner(loop=loop, asynchronous=asynchronous)
        self.loop = self._loop_runner.loop

        self._loop_runner.start()
        if self._asynchronous:
            self._started = self._start()
        else:
            self.sync(self._start)

    async def _start(self):
        self._http_client = AsyncHTTPClient()

    def close(self):
        """Close this gateway client"""
        if not self.asynchronous:
            self._loop_runner.stop()

    def __del__(self):
        # __del__ is still called, even if __init__ failed. Only close once
        # the loop runner exists. ``_started`` is not a usable marker here:
        # only the asynchronous branch of ``__init__`` sets it, so checking
        # it meant synchronous clients never stopped their loop runner.
        if hasattr(self, "_loop_runner"):
            self.close()

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

    def __await__(self):
        return self._started.__await__()

    async def __aenter__(self):
        await self._started
        return self

    async def __aexit__(self, typ, value, traceback):
        pass

    def __repr__(self):
        return "Gateway<%s>" % self.address

    @property
    def asynchronous(self):
        """True if calls should return awaitables instead of blocking."""
        return (self._asynchronous
                or getattr(thread_state, "asynchronous", False)
                or (hasattr(self.loop, "_thread_identity")
                    and self.loop._thread_identity == get_ident()))

    def sync(self, func, *args, **kwargs):
        """Call coroutine function ``func``; block on it unless asynchronous."""
        if kwargs.pop("asynchronous", None) or self.asynchronous:
            callback_timeout = kwargs.pop("callback_timeout", None)
            future = func(*args, **kwargs)
            if callback_timeout is not None:
                future = gen.with_timeout(timedelta(seconds=callback_timeout),
                                          future)
            return future
        else:
            return sync(self.loop, func, *args, **kwargs)

    async def _fetch(self, req):
        """Issue ``req`` with auth/cookies, mapping HTTP errors to exceptions.

        Raises ``TimeoutError`` for 599, ``ValueError`` for 404/422,
        ``GatewayClusterError`` for 409, ``GatewayServerError`` for 500, and
        the underlying ``HTTPError`` otherwise.
        """
        try:
            self._cookie_jar.pre_request(req)
            resp = await self._http_client.fetch(req, raise_error=False)
            if resp.code == 401:
                # Let the auth backend respond to the challenge, then retry.
                context = self._auth.pre_request(req, resp)
                resp = await self._http_client.fetch(req, raise_error=False)
                self._auth.post_response(req, resp, context)
            self._cookie_jar.post_response(resp)
            if resp.error:
                if resp.code == 599:
                    raise TimeoutError("Request timed out")
                else:
                    try:
                        msg = json.loads(resp.body)["error"]
                    except Exception:
                        msg = resp.body.decode()

                    if resp.code in {404, 422}:
                        raise ValueError(msg)
                    elif resp.code == 409:
                        raise GatewayClusterError(msg)
                    elif resp.code == 500:
                        raise GatewayServerError(msg)
                    else:
                        resp.rethrow()
        except HTTPError as exc:
            # Tornado 6 still raises these above with raise_error=False
            if exc.code == 599:
                raise TimeoutError("Request timed out")
            # Should never get here!
            raise
        return resp

    async def _clusters(self, status=None):
        if status is not None:
            if isinstance(status, (str, ClusterStatus)):
                status = [ClusterStatus._create(status)]
            else:
                status = [ClusterStatus._create(s) for s in status]
            query = "?status=" + ",".join(s.name for s in status)
        else:
            query = ""

        url = "%s/gateway/api/clusters/%s" % (self.address, query)
        req = HTTPRequest(url=url)
        resp = await self._fetch(req)
        return [
            ClusterReport._from_json(self.address, self.proxy_address, r)
            for r in json.loads(resp.body).values()
        ]

    def list_clusters(self, status=None, **kwargs):
        """List clusters for this user.

        Parameters
        ----------
        status : ClusterStatus, str, or list, optional
            The cluster status (or statuses) to select. Valid options are
            'starting', 'started', 'running', 'stopping', 'stopped', 'failed'.
            By default selects active clusters ('starting', 'started',
            'running').

        Returns
        -------
        clusters : list of ClusterReport
        """
        return self.sync(self._clusters, status=status, **kwargs)

    async def _cluster_options(self):
        url = "%s/gateway/api/clusters/options" % self.address
        req = HTTPRequest(url=url, method="GET")
        resp = await self._fetch(req)
        data = json.loads(resp.body)
        return Options._from_spec(data["cluster_options"])

    def cluster_options(self, **kwargs):
        """Get the available cluster configuration options.

        Returns
        -------
        cluster_options : Options
            A dict of cluster options.
        """
        return self.sync(self._cluster_options, **kwargs)

    async def _submit(self, cluster_options=None, **kwargs):
        url = "%s/gateway/api/clusters/" % self.address
        if cluster_options is not None:
            if not isinstance(cluster_options, Options):
                raise TypeError(
                    "cluster_options must be an `Options`, got %r"
                    % type(cluster_options).__name__)
            options = dict(cluster_options)
            # Explicit keyword options override the Options object.
            options.update(kwargs)
        else:
            options = kwargs
        req = HTTPRequest(
            url=url,
            method="POST",
            body=json.dumps({"cluster_options": options}),
            headers=HTTPHeaders({"Content-type": "application/json"}),
        )
        resp = await self._fetch(req)
        data = json.loads(resp.body)
        return data["name"]

    def submit(self, cluster_options=None, **kwargs):
        """Submit a new cluster to be started.

        This returns quickly with a ``cluster_name``, which can later be used
        to connect to the cluster.

        Parameters
        ----------
        cluster_options : Options, optional
            An ``Options`` object describing the desired cluster
            configuration.
        **kwargs :
            Additional cluster configuration options. If ``cluster_options``
            is provided, these are applied afterwards as overrides. Available
            options are specific to each deployment of dask-gateway, see
            ``cluster_options`` for more information.

        Returns
        -------
        cluster_name : str
            The cluster name.
        """
        # ``cluster_options`` used to be dropped here; forward it the same
        # way ``new_cluster`` does.
        return self.sync(self._submit, cluster_options=cluster_options,
                         **kwargs)

    async def _cluster_report(self, cluster_name, wait=False):
        params = "?wait" if wait else ""
        url = "%s/gateway/api/clusters/%s%s" % (self.address, cluster_name,
                                                params)
        req = HTTPRequest(url=url)
        resp = await self._fetch(req)
        return ClusterReport._from_json(self.address, self.proxy_address,
                                        json.loads(resp.body))

    async def _connect(self, cluster_name):
        """Poll the cluster report until it is running, failed or stopped."""
        while True:
            try:
                report = await self._cluster_report(cluster_name, wait=True)
            except TimeoutError:
                # Timeout, ignore
                pass
            else:
                if report.status is ClusterStatus.RUNNING:
                    return GatewayCluster(
                        gateway=self,
                        name=report.name,
                        scheduler_address=report.scheduler_address,
                        dashboard_link=report.dashboard_link,
                        security=report.security,
                    )
                elif report.status is ClusterStatus.FAILED:
                    raise GatewayClusterError(
                        "Cluster %r failed to start, see logs for "
                        "more information" % cluster_name)
                elif report.status is ClusterStatus.STOPPED:
                    raise GatewayClusterError("Cluster %r is already stopped"
                                              % cluster_name)
            # Not started yet, try again later
            await gen.sleep(0.5)

    def connect(self, cluster_name, **kwargs):
        """Connect to a submitted cluster.

        Returns
        -------
        cluster : GatewayCluster
        """
        return self.sync(self._connect, cluster_name, **kwargs)

    async def _new_cluster(self, **kwargs):
        cluster_name = await self._submit(**kwargs)
        try:
            return await self._connect(cluster_name)
        except GatewayClusterError:
            raise
        except BaseException:
            # Ensure cluster is stopped on error
            await self._stop_cluster(cluster_name)
            raise

    def new_cluster(self, cluster_options=None, **kwargs):
        """Submit a new cluster to the gateway, and wait for it to be started.

        Same as calling ``submit`` and ``connect`` in one go.

        Parameters
        ----------
        cluster_options : Options, optional
            An ``Options`` object describing the desired cluster
            configuration.
        **kwargs :
            Additional cluster configuration options. If ``cluster_options``
            is provided, these are applied afterwards as overrides. Available
            options are specific to each deployment of dask-gateway, see
            ``cluster_options`` for more information.

        Returns
        -------
        cluster : GatewayCluster
        """
        return self.sync(self._new_cluster, cluster_options=cluster_options,
                         **kwargs)

    async def _stop_cluster(self, cluster_name):
        url = "%s/gateway/api/clusters/%s" % (self.address, cluster_name)
        req = HTTPRequest(url=url, method="DELETE")
        await self._fetch(req)

    def stop_cluster(self, cluster_name, **kwargs):
        """Stop a cluster.

        Parameters
        ----------
        cluster_name : str
            The cluster name.
        """
        return self.sync(self._stop_cluster, cluster_name, **kwargs)

    async def _scale_cluster(self, cluster_name, n):
        url = "%s/gateway/api/clusters/%s/workers" % (self.address,
                                                      cluster_name)
        req = HTTPRequest(
            url=url,
            method="PUT",
            body=json.dumps({"worker_count": n}),
            headers=HTTPHeaders({"Content-type": "application/json"}),
        )
        await self._fetch(req)

    def scale_cluster(self, cluster_name, n, **kwargs):
        """Scale a cluster to n workers.

        Parameters
        ----------
        cluster_name : str
            The cluster name.
        n : int
            The number of workers to scale to.
        """
        return self.sync(self._scale_cluster, cluster_name, n, **kwargs)
class Gateway(object):
    """A client for a Dask Gateway Server.

    Parameters
    ----------
    address : str, optional
        The address to the gateway server.
    auth : GatewayAuth, optional
        The authentication method to use.
    asynchronous : bool, optional
        If true, starts the client in asynchronous mode, where it can be used
        in other async code.
    loop : IOLoop, optional
        The IOLoop instance to use. Defaults to the current loop in
        asynchronous mode, otherwise a background loop is started.
    """

    def __init__(self, address=None, auth=None, asynchronous=False, loop=None):
        if address is None:
            address = dask.config.get("gateway.address")
        if address is None:
            raise ValueError(
                "No dask-gateway address provided or found in configuration")
        self.address = address.rstrip("/")
        self._auth = get_auth(auth)
        self._cookie_jar = CookieJar()

        self._asynchronous = asynchronous
        self._loop_runner = LoopRunner(loop=loop, asynchronous=asynchronous)
        self.loop = self._loop_runner.loop

        self._loop_runner.start()
        if self._asynchronous:
            # Not awaited here: callers finish startup with ``await gateway``
            # or ``async with gateway`` (see ``__await__``/``__aenter__``).
            self._started = self._start()
        else:
            self.sync(self._start)

    async def _start(self):
        """Create the HTTP client used by ``_fetch``."""
        self._http_client = AsyncHTTPClient()

    def close(self):
        """Close this gateway client"""
        # Only a synchronous client stops its loop runner; an asynchronous
        # client leaves the loop running.
        if not self.asynchronous:
            self._loop_runner.stop()

    def __del__(self):
        # __del__ is still called, even if __init__ failed, so only close once
        # the loop runner exists. ``_started`` cannot be used as the marker:
        # it is only set in asynchronous mode, and then ``close`` is a no-op,
        # so synchronous clients would never stop their loop runner.
        if hasattr(self, "_loop_runner"):
            self.close()

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

    def __await__(self):
        # Only valid in asynchronous mode, where ``_started`` is set.
        return self._started.__await__()

    async def __aenter__(self):
        await self._started
        return self

    async def __aexit__(self, typ, value, traceback):
        pass

    def __repr__(self):
        return "Gateway<%s>" % self.address

    @property
    def asynchronous(self):
        """Whether calls should return awaitables instead of blocking.

        True if the client was created asynchronous, if ``thread_state``
        marks the current thread as asynchronous, or if the caller is running
        on the loop's own thread (where blocking would deadlock).
        """
        return (self._asynchronous
                or getattr(thread_state, "asynchronous", False)
                or (hasattr(self.loop, "_thread_identity")
                    and self.loop._thread_identity == get_ident()))

    def sync(self, func, *args, **kwargs):
        """Call coroutine function ``func`` in the client's mode.

        In asynchronous mode (or when ``asynchronous=True`` is passed) the
        awaitable is returned, optionally wrapped in a timeout of
        ``callback_timeout`` seconds. Otherwise the call blocks on the
        client's loop via ``sync``.
        """
        if kwargs.pop("asynchronous", None) or self.asynchronous:
            callback_timeout = kwargs.pop("callback_timeout", None)
            future = func(*args, **kwargs)
            if callback_timeout is not None:
                future = gen.with_timeout(timedelta(seconds=callback_timeout),
                                          future)
            return future
        else:
            # Module-level ``sync`` helper, not this method.
            return sync(self.loop, func, *args, **kwargs)

    async def _fetch(self, req, raise_error=True):
        """Send ``req`` and return the response.

        Cookies are applied and recorded through the cookie jar. On a 401 the
        auth object gets a chance to update the request, which is then resent
        once. If ``raise_error`` is true, error responses are re-raised.
        """
        self._cookie_jar.pre_request(req)
        resp = await self._http_client.fetch(req, raise_error=False)
        if resp.code == 401:
            context = self._auth.pre_request(req, resp)
            resp = await self._http_client.fetch(req, raise_error=False)
            self._auth.post_response(req, resp, context)
        self._cookie_jar.post_response(resp)
        if raise_error:
            resp.rethrow()
        return resp

    async def _clusters(self, status=None):
        """Fetch cluster reports, optionally filtered by one or more statuses."""
        if status is not None:
            if isinstance(status, (str, ClusterStatus)):
                status = [ClusterStatus._create(status)]
            else:
                status = [ClusterStatus._create(s) for s in status]
            query = "?status=" + ",".join(s.name for s in status)
        else:
            query = ""

        url = "%s/gateway/api/clusters/%s" % (self.address, query)
        req = HTTPRequest(url=url)
        resp = await self._fetch(req)
        return [
            ClusterReport._from_json(self.address, r)
            for r in json.loads(resp.body).values()
        ]

    def list_clusters(self, status=None, **kwargs):
        """List clusters for this user.

        Parameters
        ----------
        status : ClusterStatus, str, or list, optional
            The cluster status (or statuses) to select. Valid options are
            'starting', 'started', 'running', 'stopping', 'stopped', 'failed'.
            By default selects active clusters ('starting', 'started',
            'running').

        Returns
        -------
        clusters : list of ClusterReport
        """
        return self.sync(self._clusters, status=status, **kwargs)

    async def _submit(self):
        """POST a new cluster request and return the new cluster's name."""
        url = "%s/gateway/api/clusters/" % self.address
        req = HTTPRequest(
            url=url,
            method="POST",
            body=json.dumps({}),
            headers=HTTPHeaders({"Content-type": "application/json"}),
        )
        resp = await self._fetch(req)
        data = json.loads(resp.body)
        return data["name"]

    def submit(self, **kwargs):
        """Submit a new cluster to be started.

        This returns quickly with a ``cluster_name``, which can later be used
        to connect to the cluster.

        Returns
        -------
        cluster_name : str
            The cluster name.
        """
        return self.sync(self._submit, **kwargs)

    async def _cluster_report(self, cluster_name, wait=False):
        """Fetch the report for one cluster; ``wait`` adds the ``?wait`` flag."""
        params = "?wait" if wait else ""
        url = "%s/gateway/api/clusters/%s%s" % (self.address, cluster_name,
                                                params)
        req = HTTPRequest(url=url)
        resp = await self._fetch(req)
        return ClusterReport._from_json(self.address, json.loads(resp.body))

    async def _connect(self, cluster_name):
        """Poll until ``cluster_name`` is running, then wrap it in a cluster.

        Raises if the cluster is unknown (404), failed, or already stopped.
        """
        while True:
            try:
                report = await self._cluster_report(cluster_name, wait=True)
            except HTTPError as exc:
                if exc.code == 404:
                    raise Exception("Unknown cluster %r" % cluster_name)
                elif exc.code == 599:
                    # Timeout, ignore
                    pass
                else:
                    raise
            else:
                if report.status is ClusterStatus.RUNNING:
                    return GatewayCluster(self, report)
                elif report.status is ClusterStatus.FAILED:
                    raise Exception("Cluster %r failed to start, see logs for "
                                    "more information" % cluster_name)
                elif report.status is ClusterStatus.STOPPED:
                    raise Exception("Cluster %r is already stopped" %
                                    cluster_name)
            # Not started yet, try again later
            await gen.sleep(0.5)

    def connect(self, cluster_name, **kwargs):
        """Connect to a submitted cluster.

        Returns
        -------
        cluster : GatewayCluster
        """
        return self.sync(self._connect, cluster_name, **kwargs)

    async def _new_cluster(self, **kwargs):
        """Submit a cluster and connect to it, stopping it if connecting fails."""
        cluster_name = await self._submit(**kwargs)
        try:
            return await self._connect(cluster_name)
        except BaseException:
            # Ensure cluster is stopped on error
            await self._stop_cluster(cluster_name)
            raise

    def new_cluster(self, **kwargs):
        """Submit a new cluster to the gateway, and wait for it to be started.

        Same as calling ``submit`` and ``connect`` in one go.

        Returns
        -------
        cluster : GatewayCluster
        """
        return self.sync(self._new_cluster, **kwargs)

    async def _stop_cluster(self, cluster_name):
        """Send a DELETE request for ``cluster_name``."""
        url = "%s/gateway/api/clusters/%s" % (self.address, cluster_name)
        req = HTTPRequest(url=url, method="DELETE")
        await self._fetch(req)

    def stop_cluster(self, cluster_name, **kwargs):
        """Stop a cluster.

        Parameters
        ----------
        cluster_name : str
            The cluster name.
        """
        return self.sync(self._stop_cluster, cluster_name, **kwargs)

    async def _scale_cluster(self, cluster_name, n):
        """PUT the requested worker count; a 409 means the cluster isn't running."""
        url = "%s/gateway/api/clusters/%s/workers" % (self.address,
                                                      cluster_name)
        req = HTTPRequest(
            url=url,
            method="PUT",
            body=json.dumps({"worker_count": n}),
            headers=HTTPHeaders({"Content-type": "application/json"}),
        )
        try:
            await self._fetch(req)
        except HTTPError as exc:
            if exc.code == 409:
                raise Exception("Cluster %r is not running" % cluster_name)
            raise

    def scale_cluster(self, cluster_name, n, **kwargs):
        """Scale a cluster to n workers.

        Parameters
        ----------
        cluster_name : str
            The cluster name.
        n : int
            The number of workers to scale to.
        """
        return self.sync(self._scale_cluster, cluster_name, n, **kwargs)
class Gateway(object):
    """A client for a Dask Gateway Server.

    Parameters
    ----------
    address : str, optional
        The address to the gateway server.
    proxy_address : str, int, optional
        The address of the scheduler proxy server. Defaults to `address` if not
        provided. If an int, it's used as the port, with the host/ip taken from
        ``address``. Provide a full address if a different host/ip should be
        used.
    auth : GatewayAuth, optional
        The authentication method to use.
    asynchronous : bool, optional
        If true, starts the client in asynchronous mode, where it can be used
        in other async code.
    loop : IOLoop, optional
        The IOLoop instance to use. Defaults to the current loop in
        asynchronous mode, otherwise a background loop is started.
    """

    def __init__(self,
                 address=None,
                 proxy_address=None,
                 auth=None,
                 asynchronous=False,
                 loop=None):
        if address is None:
            address = format_template(dask.config.get("gateway.address"))
        if address is None:
            raise ValueError(
                "No dask-gateway address provided or found in configuration")
        address = address.rstrip("/")

        # Address used when building cluster reports; falls back to the
        # gateway address.
        public_address = format_template(
            dask.config.get("gateway.public-address"))
        if public_address is None:
            public_address = address
        else:
            public_address = public_address.rstrip("/")

        if proxy_address is None:
            proxy_address = format_template(
                dask.config.get("gateway.proxy-address"))
        if proxy_address is None:
            # Derive the proxy host/port from the gateway address.
            parsed = urlparse(address)
            if parsed.netloc:
                if parsed.port is None:
                    proxy_port = {
                        "http": 80,
                        "https": 443
                    }.get(parsed.scheme, 8786)
                else:
                    proxy_port = parsed.port
                proxy_netloc = "%s:%d" % (parsed.hostname, proxy_port)
            else:
                # No netloc to take a host from, so use the gateway address
                # verbatim (``proxy_address`` is None on this branch).
                proxy_netloc = address
        elif isinstance(proxy_address, int):
            parsed = urlparse(address)
            proxy_netloc = "%s:%d" % (parsed.hostname, proxy_address)
        elif isinstance(proxy_address, str):
            parsed = urlparse(proxy_address)
            proxy_netloc = parsed.netloc if parsed.netloc else proxy_address
        else:
            raise TypeError("proxy_address must be a str or int, got %r" %
                            type(proxy_address).__name__)
        proxy_address = "gateway://%s" % proxy_netloc

        scheme = urlparse(address).scheme
        self._request_kwargs = _get_default_request_kwargs(scheme)

        self.address = address
        self._public_address = public_address
        self.proxy_address = proxy_address

        self.auth = get_auth(auth)
        # Created lazily on the first request (see ``_request``).
        self._session = None

        # ``_asynchronous`` is assigned before ``_loop_runner``; ``__del__``
        # relies on this ordering.
        self._asynchronous = asynchronous
        self._loop_runner = LoopRunner(loop=loop, asynchronous=asynchronous)
        self._loop_runner.start()

    @property
    def loop(self):
        return self._loop_runner.loop

    @property
    def asynchronous(self):
        return self._asynchronous

    def sync(self, func, *args, **kwargs):
        """Call coroutine function ``func`` in the client's mode.

        In asynchronous mode the coroutine is returned for the caller to
        await. Otherwise it is scheduled on the loop runner's asyncio loop and
        this call blocks for the result. If waiting is interrupted by any
        exception (including ``KeyboardInterrupt``), the scheduled future is
        cancelled before re-raising.
        """
        if self.asynchronous:
            return func(*args, **kwargs)
        else:
            future = asyncio.run_coroutine_threadsafe(func(*args, **kwargs),
                                                      self.loop.asyncio_loop)
            try:
                return future.result()
            except BaseException:
                future.cancel()
                raise

    def close(self):
        """Close the gateway client"""
        if self.asynchronous:
            # Caller must await the returned coroutine.
            return self._cleanup()
        elif self.loop.asyncio_loop.is_running():
            self.sync(self._cleanup)
        self._loop_runner.stop()

    async def _cleanup(self):
        """Close and forget the HTTP session, if one was created."""
        if self._session is not None:
            await self._session.close()
            self._session = None

    async def __aenter__(self):
        return self

    async def __aexit__(self, typ, value, traceback):
        await self._cleanup()

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

    def __del__(self):
        # __del__ also runs when __init__ raised part way through. Check for
        # ``_loop_runner`` first: it is assigned after ``_asynchronous``, so
        # ``self.asynchronous`` is safe to read only once it exists.
        if (hasattr(self, "_loop_runner") and not self.asynchronous
                and not sys.is_finalizing()):
            self.close()

    def __repr__(self):
        return "Gateway<%s>" % self.address

    async def _request(self, method, url, json=None):
        """Send an HTTP request and return the response on success.

        On a 401 the request is resent once with headers from
        ``auth.pre_request``. Error statuses map to exceptions:
        404/422 -> ``ValueError``, 409 -> ``GatewayClusterError``,
        500 -> ``GatewayServerError``, anything else -> aiohttp's
        ``raise_for_status``. The server's ``"error"`` field is used as the
        message when the body is JSON, else the raw body text.
        """
        if self._session is None:
            self._session = aiohttp.ClientSession()
        session = self._session

        resp = await session.request(method,
                                     url,
                                     json=json,
                                     **self._request_kwargs)

        if resp.status == 401:
            headers, context = self.auth.pre_request(resp)
            resp = await session.request(method,
                                         url,
                                         json=json,
                                         headers=headers,
                                         **self._request_kwargs)
            self.auth.post_response(resp, context)

        if resp.status >= 400:
            try:
                msg = await resp.json()
                msg = msg["error"]
            except Exception:
                msg = await resp.text()

            if resp.status in {404, 422}:
                raise ValueError(msg)
            elif resp.status == 409:
                raise GatewayClusterError(msg)
            elif resp.status == 500:
                raise GatewayServerError(msg)
            else:
                resp.raise_for_status()
        else:
            return resp

    async def _clusters(self, status=None):
        """Fetch cluster reports, optionally filtered by one or more statuses."""
        if status is not None:
            if isinstance(status, (str, ClusterStatus)):
                status = [ClusterStatus._create(status)]
            else:
                status = [ClusterStatus._create(s) for s in status]
            query = "?status=" + ",".join(s.name for s in status)
        else:
            query = ""

        url = "%s/api/v1/clusters/%s" % (self.address, query)
        resp = await self._request("GET", url)
        data = await resp.json()
        return [
            ClusterReport._from_json(self._public_address, self.proxy_address,
                                     r) for r in data.values()
        ]

    def list_clusters(self, status=None, **kwargs):
        """List clusters for this user.

        Parameters
        ----------
        status : ClusterStatus, str, or list, optional
            The cluster status (or statuses) to select. Valid options are
            'pending', 'running', 'stopping', 'stopped', 'failed'.
            By default selects active clusters ('pending', 'running').

        Returns
        -------
        clusters : list of ClusterReport
        """
        return self.sync(self._clusters, status=status, **kwargs)

    def get_cluster(self, cluster_name, **kwargs):
        """Get information about a specific cluster.

        Parameters
        ----------
        cluster_name : str
            The cluster name.

        Returns
        -------
        report : ClusterReport
        """
        return self.sync(self._cluster_report, cluster_name, **kwargs)

    async def _get_versions(self):
        """Return the server's version payload and the client package version."""
        url = "%s/api/version" % self.address
        resp = await self._request("GET", url)
        server_info = await resp.json()
        from . import __version__

        return {
            "server": server_info,
            "client": {
                "version": __version__
            },
        }

    def get_versions(self):
        """Return version info for the server and client

        Returns
        -------
        version_info : dict
        """
        return self.sync(self._get_versions)

    def _config_cluster_options(self):
        """Return ``gateway.cluster.options`` from config, values templated."""
        opts = dask.config.get("gateway.cluster.options")
        return {k: format_template(v) for k, v in opts.items()}

    async def _cluster_options(self, use_local_defaults=True):
        """Fetch the server's option spec, optionally overlaid with config."""
        url = "%s/api/v1/options" % self.address
        resp = await self._request("GET", url)
        data = await resp.json()
        options = Options._from_spec(data["cluster_options"])
        if use_local_defaults:
            options.update(self._config_cluster_options())
        return options

    def cluster_options(self, use_local_defaults=True, **kwargs):
        """Get the available cluster configuration options.

        Parameters
        ----------
        use_local_defaults : bool, optional
            Whether to use any default options from the local configuration.
            Default is True, set to False to use only the server-side defaults.

        Returns
        -------
        cluster_options : dask_gateway.options.Options
            A dict of cluster options.
        """
        return self.sync(self._cluster_options,
                         use_local_defaults=use_local_defaults,
                         **kwargs)

    async def _submit(self, cluster_options=None, **kwargs):
        """POST a new cluster request and return the new cluster's name.

        Without ``cluster_options`` the configured defaults are used. In both
        cases ``kwargs`` are applied on top as overrides.
        """
        url = "%s/api/v1/clusters/" % self.address
        if cluster_options is not None:
            if not isinstance(cluster_options, Options):
                raise TypeError(
                    "cluster_options must be an `Options`, got %r" %
                    type(cluster_options).__name__)
            options = dict(cluster_options)
            options.update(kwargs)
        else:
            options = self._config_cluster_options()
            options.update(kwargs)
        resp = await self._request("POST",
                                   url,
                                   json={"cluster_options": options})
        data = await resp.json()
        return data["name"]

    def submit(self, cluster_options=None, **kwargs):
        """Submit a new cluster to be started.

        This returns quickly with a ``cluster_name``, which can later be used
        to connect to the cluster.

        Parameters
        ----------
        cluster_options : dask_gateway.options.Options, optional
            An ``Options`` object describing the desired cluster configuration.
        **kwargs :
            Additional cluster configuration options. If ``cluster_options`` is
            provided, these are applied afterwards as overrides. Available
            options are specific to each deployment of dask-gateway, see
            ``cluster_options`` for more information.

        Returns
        -------
        cluster_name : str
            The cluster name.
        """
        return self.sync(self._submit,
                         cluster_options=cluster_options,
                         **kwargs)

    async def _cluster_report(self, cluster_name, wait=False):
        """Fetch the report for one cluster; ``wait`` adds the ``?wait`` flag."""
        params = "?wait" if wait else ""
        url = "%s/api/v1/clusters/%s%s" % (self.address, cluster_name, params)
        resp = await self._request("GET", url)
        data = await resp.json()
        return ClusterReport._from_json(self._public_address,
                                        self.proxy_address, data)

    async def _wait_for_start(self, cluster_name):
        """Poll until ``cluster_name`` is running and return its report.

        Raises ``GatewayClusterError`` if the cluster failed or was stopped.
        """
        while True:
            try:
                report = await self._cluster_report(cluster_name, wait=True)
            except (TimeoutError, asyncio.TimeoutError):
                # Timeout, ignore. asyncio/aiohttp timeouts raise
                # ``asyncio.TimeoutError``, which is only an alias of the
                # builtin ``TimeoutError`` from Python 3.11 on.
                pass
            else:
                if report.status is ClusterStatus.RUNNING:
                    return report
                elif report.status is ClusterStatus.FAILED:
                    raise GatewayClusterError(
                        "Cluster %r failed to start, see logs for "
                        "more information" % cluster_name)
                elif report.status is ClusterStatus.STOPPED:
                    raise GatewayClusterError("Cluster %r is already stopped" %
                                              cluster_name)
            # Not started yet, try again later
            await asyncio.sleep(0.5)

    def connect(self, cluster_name, shutdown_on_close=False):
        """Connect to a submitted cluster.

        Parameters
        ----------
        cluster_name : str
            The cluster to connect to.
        shutdown_on_close : bool, optional
            If True, the cluster will be automatically shutdown on close.
            Default is False.

        Returns
        -------
        cluster : GatewayCluster
        """
        return GatewayCluster.from_name(
            cluster_name,
            shutdown_on_close=shutdown_on_close,
            address=self.address,
            proxy_address=self.proxy_address,
            auth=self.auth,
            asynchronous=self.asynchronous,
            loop=self.loop,
        )

    def new_cluster(self,
                    cluster_options=None,
                    shutdown_on_close=True,
                    **kwargs):
        """Submit a new cluster to the gateway, and wait for it to be started.

        Same as calling ``submit`` and ``connect`` in one go.

        Parameters
        ----------
        cluster_options : dask_gateway.options.Options, optional
            An ``Options`` object describing the desired cluster configuration.
        shutdown_on_close : bool, optional
            If True (default), the cluster will be automatically shutdown on
            close. Set to False to have cluster persist until explicitly
            shutdown.
        **kwargs :
            Additional cluster configuration options. If ``cluster_options`` is
            provided, these are applied afterwards as overrides. Available
            options are specific to each deployment of dask-gateway, see
            ``cluster_options`` for more information.

        Returns
        -------
        cluster : GatewayCluster
        """
        return GatewayCluster(
            address=self.address,
            proxy_address=self.proxy_address,
            auth=self.auth,
            asynchronous=self.asynchronous,
            loop=self.loop,
            cluster_options=cluster_options,
            shutdown_on_close=shutdown_on_close,
            **kwargs,
        )

    async def _stop_cluster(self, cluster_name):
        """Send a DELETE request for ``cluster_name``."""
        url = "%s/api/v1/clusters/%s" % (self.address, cluster_name)
        await self._request("DELETE", url)

    def stop_cluster(self, cluster_name, **kwargs):
        """Stop a cluster.

        Parameters
        ----------
        cluster_name : str
            The cluster name.
        """
        return self.sync(self._stop_cluster, cluster_name, **kwargs)

    async def _scale_cluster(self, cluster_name, n):
        """POST the requested worker count to the cluster's scale endpoint."""
        url = "%s/api/v1/clusters/%s/scale" % (self.address, cluster_name)
        await self._request("POST", url, json={"count": n})

    def scale_cluster(self, cluster_name, n, **kwargs):
        """Scale a cluster to n workers.

        Parameters
        ----------
        cluster_name : str
            The cluster name.
        n : int
            The number of workers to scale to.
        """
        return self.sync(self._scale_cluster, cluster_name, n, **kwargs)

    async def _adapt_cluster(self,
                             cluster_name,
                             minimum=None,
                             maximum=None,
                             active=True):
        """POST adaptive-scaling settings to the cluster's adapt endpoint."""
        await self._request(
            "POST",
            "%s/api/v1/clusters/%s/adapt" % (self.address, cluster_name),
            json={
                "minimum": minimum,
                "maximum": maximum,
                "active": active
            },
        )

    def adapt_cluster(self,
                      cluster_name,
                      minimum=None,
                      maximum=None,
                      active=True,
                      **kwargs):
        """Configure adaptive scaling for a cluster.

        Parameters
        ----------
        cluster_name : str
            The cluster name.
        minimum : int, optional
            The minimum number of workers to scale to. Defaults to 0.
        maximum : int, optional
            The maximum number of workers to scale to. Defaults to infinity.
        active : bool, optional
            If ``True`` (default), adaptive scaling is activated. Set to
            ``False`` to deactivate adaptive scaling.
        """
        return self.sync(
            self._adapt_cluster,
            cluster_name,
            minimum=minimum,
            maximum=maximum,
            active=active,
            **kwargs,
        )