def test_security_temporary(EnvSpecificCluster, loop):
    """Run a cluster with temporary TLS credentials and a shared temp dir.

    Checks that the generated job script points at key, cert and CA files
    that exist inside the shared directory, then runs one task end to end.
    """
    here = os.path.dirname(__file__)
    # (regex locating the CLI flag in the job script, expected file-name prefix)
    expected_tls_files = (
        (r"--tls-key (\S+)", ".dask-jobqueue.worker.key"),
        (r"--tls-cert (\S+)", ".dask-jobqueue.worker.cert"),
        (r"--tls-ca-file (\S+)", ".dask-jobqueue.worker.ca_file"),
    )
    cluster_kwargs = dict(
        cores=1,
        memory="100MB",
        security=Security.temporary(),
        shared_temp_directory=here,
        protocol="tls",
        loop=loop,
    )

    with EnvSpecificCluster(**cluster_kwargs) as cluster:
        assert cluster.security
        scheduler_security = cluster.scheduler_spec["options"]["security"]
        assert scheduler_security == cluster.security

        script = cluster.job_script()
        assert "tls://" in script

        for flag_pattern, prefix in expected_tls_files:
            path = re.findall(flag_pattern, script)[0]
            assert os.path.exists(path)
            assert os.path.basename(path).startswith(prefix)
            assert os.path.dirname(path) == here

        cluster.scale(jobs=1)
        with Client(cluster) as client:
            outcome = client.submit(lambda x: x + 1, 10).result(timeout=30)
            assert outcome == 11
def test_extra_conn_args_in_temporary_credentials():
    """``extra_conn_args`` given to ``Security.temporary`` is kept on the object."""
    pytest.importorskip("cryptography")
    request_headers = {"X-Request-ID": "abcd"}
    credentials = Security.temporary(
        extra_conn_args={"headers": dict(request_headers)}
    )
    # Compare against an independently built dict, not the object passed in.
    assert credentials.extra_conn_args == {"headers": request_headers}
def test_temporary_credentials():
    """Temporary credentials hold multi-line values that ``repr`` never shows.

    Every TLS field (the CA file plus key/cert for client, scheduler and
    worker) must contain a newline and must not appear in ``repr(sec)``.
    """
    # Generating temporary credentials needs the optional `cryptography`
    # package; skip instead of erroring when it is not installed.
    pytest.importorskip("cryptography")

    sec = Security.temporary()
    sec_repr = repr(sec)

    fields = ["tls_ca_file"]
    fields.extend(
        f"tls_{role}_{kind}"
        for role in ["client", "scheduler", "worker"]
        for kind in ["key", "cert"]
    )
    for f in fields:
        val = getattr(sec, f)
        assert "\n" in val
        assert val not in sec_repr
def test_temporary_credentials():
    """Each temporary TLS field is multi-line and absent from the repr."""
    pytest.importorskip("cryptography")
    credentials = Security.temporary()
    shown = repr(credentials)

    attr_names = ["tls_ca_file"] + [
        f"tls_{role}_{kind}"
        for role in ("client", "scheduler", "worker")
        for kind in ("key", "cert")
    ]
    for attr in attr_names:
        material = getattr(credentials, attr)
        assert "\n" in material
        assert material not in shown
async def test_wss_roundtrip():
    """Scatter a NumPy array over ``wss://`` and read the same values back."""
    np = pytest.importorskip("numpy")
    xfail_ssl_issue5601()
    pytest.importorskip("cryptography")

    credentials = Security.temporary()
    async with Scheduler(
        protocol="wss://", security=credentials, dashboard_address=":0"
    ) as scheduler:
        async with Worker(scheduler.address, security=credentials), Client(
            scheduler.address, security=credentials, asynchronous=True
        ) as client:
            original = np.arange(100)
            remote = await client.scatter(original)
            fetched = await remote
            assert (original == fetched).all()
async def test_http_and_comm_server(cleanup, dashboard, protocol, security, port):
    """The HTTP server and comm listener share a server only on port 8787.

    A task is also submitted through a worker and client to confirm the
    cluster works with the given protocol/security combination.
    """
    if security:
        xfail_ssl_issue5601()
        pytest.importorskip("cryptography")
        security = Security.temporary()

    async with Scheduler(
        protocol=protocol, dashboard=dashboard, port=port, security=security
    ) as scheduler:
        shares_server = scheduler.http_server is scheduler.listener.server
        assert shares_server == (port == 8787)

        async with Worker(
            scheduler.address, protocol=protocol, security=security
        ), Client(scheduler.address, asynchronous=True, security=security) as client:
            answer = await client.submit(lambda x: x + 1, 10)
            assert answer == 11
async def test_tls_temporary_credentials_functional():
    """Exchange one message over TLS using only temporary credentials."""
    pytest.importorskip("cryptography")

    async def handle_comm(comm):
        # Listener side: the peer must have connected over TLS.
        assert comm.peer_address.startswith("tls://")
        await comm.write("hello")
        await comm.close()

    credentials = Security.temporary()
    listen_args = credentials.get_listen_args("scheduler")

    async with listen("tls://", handle_comm, **listen_args) as listener:
        client_comm = await connect(
            listener.contact_address, **credentials.get_connection_args("worker")
        )
        received = await client_comm.read()
        assert received == "hello"
        client_comm.abort()
async def test_expect_scheduler_ssl_when_sharing_server(tmpdir):
    """A ``ws://`` scheduler on port 8787 with dashboard TLS configured must fail.

    The scheduler key and cert from temporary credentials are written to
    files, configured as the dashboard TLS key/cert, and starting the
    scheduler is then expected to raise ``RuntimeError``.
    """
    xfail_ssl_issue5601()
    pytest.importorskip("cryptography")
    credentials = Security.temporary()

    workdir = str(tmpdir)
    key_path = os.path.join(workdir, "dask.pem")
    cert_path = os.path.join(workdir, "dask.crt")
    for target, pem_text in (
        (key_path, credentials.tls_scheduler_key),
        (cert_path, credentials.tls_scheduler_cert),
    ):
        with open(target, "w") as handle:
            handle.write(pem_text)

    dashboard_tls = {
        "distributed.scheduler.dashboard.tls.key": key_path,
        "distributed.scheduler.dashboard.tls.cert": cert_path,
    }
    with dask.config.set(dashboard_tls), pytest.raises(RuntimeError):
        async with Scheduler(protocol="ws://", dashboard=True, port=8787):
            pass
async def test_connection_made_with_extra_conn_args(cleanup, protocol):
    """Headers passed via ``extra_conn_args`` reach the scheduler's request."""
    extra = {"headers": {"Authorization": "Token abcd"}}
    if protocol != "ws://":
        # Any other protocol is exercised with temporary TLS credentials.
        xfail_ssl_issue5601()
        pytest.importorskip("cryptography")
        security = Security.temporary(extra_conn_args=extra)
    else:
        security = Security(extra_conn_args=extra)

    async with Scheduler(
        protocol=protocol, security=security, dashboard_address=":0"
    ) as scheduler:
        worker_args = security.get_connection_args("worker")
        comm = await connect(scheduler.address, **worker_args)
        sent_header = comm.sock.request.headers.get("Authorization")
        assert sent_header == "Token abcd"
        await comm.close()
def __init__(
    self,
    n_workers=0,
    job_cls: Job = None,
    # Cluster keywords
    loop=None,
    security=None,
    shared_temp_directory=None,
    silence_logs="error",
    name=None,
    asynchronous=False,
    # Scheduler-only keywords
    dashboard_address=None,
    host=None,
    scheduler_options=None,
    scheduler_cls=Scheduler,  # Use local scheduler for now
    # Options for both scheduler and workers
    interface=None,
    protocol=None,
    # Job keywords
    config_name=None,
    **job_kwargs
):
    """Build the scheduler and worker specs and initialise the cluster.

    Parameters
    ----------
    n_workers
        If truthy, ``self.scale(n_workers)`` is called once initialisation
        has finished.
    job_cls
        Job type to use; overrides the ``job_cls`` class attribute.
    loop, silence_logs, name, asynchronous
        Forwarded unchanged to the parent constructor.
    security
        Placed in the scheduler options and forwarded to the parent
        constructor. ``True`` is replaced by ``Security.temporary()``.
        Jobs receive ``self._get_worker_security(security)``.
    shared_temp_directory
        Stored on ``self.shared_temp_directory``. Defaults to the
        ``jobqueue.<config_name>.shared-temp-directory`` config value.
    dashboard_address, host
        Must be left as ``None``; these settings have to be passed through
        ``scheduler_options`` instead.
    scheduler_options
        Mapping merged over the default scheduler options (``protocol``,
        ``dashboard_address=":8787"``, ``security``). Defaults to the
        ``jobqueue.<config_name>.scheduler-options`` config value, or ``{}``.
    scheduler_cls
        Class used for the scheduler spec.
    interface
        Network interface handed to the jobs. It is also used for the
        scheduler unless ``scheduler_options`` contains ``host`` or
        ``interface``. Defaults to ``jobqueue.<config_name>.interface``.
    protocol
        Protocol for scheduler and jobs. Defaults to ``"tls://"`` when
        ``security`` is not ``None``.
    config_name
        Configuration section name. Defaults to
        ``self.job_cls.default_config_name()``.
    **job_kwargs
        Remaining keywords, used as the options of the job spec.

    Raises
    ------
    ValueError
        If no job class is available, or if ``dashboard_address`` or
        ``host`` is passed directly.
    ImportError
        If ``security=True`` and temporary credentials cannot be created
        because of a failed import.
    """
    self.status = Status.created

    # An explicit ``job_cls`` argument wins over the class-level attribute.
    default_job_cls = getattr(type(self), "job_cls", None)
    self.job_cls = default_job_cls
    if job_cls is not None:
        self.job_cls = job_cls

    if self.job_cls is None:
        raise ValueError(
            "You need to specify a Job type. Two cases:\n"
            "- you are inheriting from JobQueueCluster (most likely): you need to add a 'job_cls' class variable "
            "in your JobQueueCluster-derived class {}\n"
            "- you are using JobQueueCluster directly (less likely, only useful for tests): "
            "please explicitly pass a Job type through the 'job_cls' parameter.".format(
                type(self)
            )
        )

    if dashboard_address is not None:
        raise ValueError(
            "Please pass 'dashboard_address' through 'scheduler_options': use\n"
            'cluster = {0}(..., scheduler_options={{"dashboard_address": ":12345"}}) rather than\n'
            'cluster = {0}(..., dashboard_address="12435")'.format(
                self.__class__.__name__
            )
        )

    if host is not None:
        raise ValueError(
            "Please pass 'host' through 'scheduler_options': use\n"
            'cluster = {0}(..., scheduler_options={{"host": "your-host"}}) rather than\n'
            'cluster = {0}(..., host="your-host")'.format(self.__class__.__name__)
        )

    default_config_name = self.job_cls.default_config_name()
    if config_name is None:
        config_name = default_config_name

    if interface is None:
        interface = dask.config.get("jobqueue.%s.interface" % config_name)
    if scheduler_options is None:
        scheduler_options = dask.config.get(
            "jobqueue.%s.scheduler-options" % config_name, {}
        )

    # Any security setting implies TLS unless a protocol was chosen explicitly.
    if protocol is None and security is not None:
        protocol = "tls://"

    if security is True:
        try:
            security = Security.temporary()
        except ImportError as err:
            # Chain the original error so the underlying import failure
            # stays visible in the traceback.
            raise ImportError(
                "In order to use TLS without pregenerated certificates `cryptography` is required, "
                "please install it using either pip or conda"
            ) from err

    default_scheduler_options = {
        "protocol": protocol,
        "dashboard_address": ":8787",
        "security": security,
    }
    # scheduler_options overrides parameters common to both workers and scheduler
    scheduler_options = dict(default_scheduler_options, **scheduler_options)

    # Use the same network interface as the workers if scheduler ip has not
    # been set through scheduler_options via 'host' or 'interface'
    if "host" not in scheduler_options and "interface" not in scheduler_options:
        scheduler_options["interface"] = interface

    scheduler = {
        "cls": scheduler_cls,
        "options": scheduler_options,
    }

    if shared_temp_directory is None:
        shared_temp_directory = dask.config.get(
            "jobqueue.%s.shared-temp-directory" % config_name
        )
    self.shared_temp_directory = shared_temp_directory

    job_kwargs["config_name"] = config_name
    job_kwargs["interface"] = interface
    job_kwargs["protocol"] = protocol
    job_kwargs["security"] = self._get_worker_security(security)

    self._job_kwargs = job_kwargs

    worker = {"cls": self.job_cls, "options": self._job_kwargs}
    # With several processes per job, add one "-<i>" group entry per process.
    if "processes" in self._job_kwargs and self._job_kwargs["processes"] > 1:
        worker["group"] = [
            "-" + str(i) for i in range(self._job_kwargs["processes"])
        ]

    self._dummy_job  # trigger property to ensure that the job is valid

    super().__init__(
        scheduler=scheduler,
        worker=worker,
        loop=loop,
        security=security,
        silence_logs=silence_logs,
        asynchronous=asynchronous,
        name=name,
    )

    if n_workers:
        self.scale(n_workers)
def test_repr_temp_keys():
    """The repr of temporary credentials marks them as in-memory."""
    # Generating temporary credentials needs the optional `cryptography`
    # package; skip instead of erroring when it is not installed.
    pytest.importorskip("cryptography")
    sec = Security.temporary()
    representation = repr(sec)
    assert "Temporary (In-memory)" in representation
def __init__(
    self,
    n_workers: int = 0,
    worker_class: str = "dask.distributed.Nanny",
    worker_options: dict = None,
    scheduler_options: dict = None,
    docker_image="daskdev/dask:latest",
    docker_args: str = "",
    env_vars: dict = None,
    security: bool = True,
    protocol: str = None,
    **kwargs,
):
    """Resolve security/protocol and fill in the scheduler and worker options.

    Parameters
    ----------
    n_workers
        Stored on ``self._n_workers``.
    worker_class
        Stored under ``"worker_class"`` in ``self.worker_options``.
    worker_options, scheduler_options
        Stored under ``"worker_options"`` / ``"scheduler_options"`` in
        ``self.worker_options`` / ``self.scheduler_options``. ``None``
        (the default) means a fresh empty dict.
    docker_image
        Image name used unless ``self.scheduler_options`` already holds a
        truthy ``"docker_image"``.
    docker_args
        Copied into ``self.options``, ``self.scheduler_options`` and
        ``self.worker_options``.
    env_vars
        Stored under ``"env_vars"`` for both scheduler and workers.
        ``None`` (the default) means a fresh empty dict.
    security
        Falsey disables security, ``True`` creates temporary credentials
        with ``Security.temporary()``, and a ``Security`` instance is used
        as given.
    protocol
        Explicit protocol. When ``None``, ``self.protocol`` becomes
        ``"tls"`` if the security object requires encryption, else ``"tcp"``.
    **kwargs
        Forwarded to the parent constructor.

    Raises
    ------
    RuntimeError
        If ``self.scheduler_class`` or ``self.worker_class`` is ``None``.
    TypeError
        If ``security`` is truthy but neither ``True`` nor a ``Security``.
    """
    if self.scheduler_class is None or self.worker_class is None:
        raise RuntimeError(
            "VMCluster is not intended to be used directly. See docstring for more info."
        )

    # Build fresh dicts per call. A ``{}`` default would be a single object
    # shared by every instance, because these values are stored on ``self``.
    worker_options = {} if worker_options is None else worker_options
    scheduler_options = {} if scheduler_options is None else scheduler_options
    env_vars = {} if env_vars is None else env_vars

    self._n_workers = n_workers

    if not security:
        self.security = None
    elif security is True:
        # True indicates self-signed temporary credentials should be used
        self.security = Security.temporary()
    elif not isinstance(security, Security):
        raise TypeError("security must be a Security object")
    else:
        self.security = security

    if protocol is None:
        if self.security and self.security.require_encryption:
            self.protocol = "tls"
        else:
            self.protocol = "tcp"
    else:
        self.protocol = protocol

    if self.security and self.security.require_encryption:
        dask.config.set(
            {
                "distributed.comm.default-scheme": self.protocol,
                "distributed.comm.require-encryption": True,
                "distributed.comm.tls.ca-file": self.security.tls_ca_file,
                "distributed.comm.tls.scheduler.key": self.security.tls_scheduler_key,
                "distributed.comm.tls.scheduler.cert": self.security.tls_scheduler_cert,
                "distributed.comm.tls.worker.key": self.security.tls_worker_key,
                "distributed.comm.tls.worker.cert": self.security.tls_worker_cert,
                "distributed.comm.tls.client.key": self.security.tls_client_key,
                "distributed.comm.tls.client.cert": self.security.tls_client_cert,
            }
        )

    # A docker image already present in the scheduler options takes precedence.
    image = self.scheduler_options.get("docker_image", False) or docker_image
    self.options["docker_image"] = image
    self.scheduler_options["docker_image"] = image
    self.scheduler_options["env_vars"] = env_vars
    # NOTE(review): the raw ``protocol`` argument (possibly None) is forwarded
    # here and to the workers below, not the resolved ``self.protocol`` --
    # confirm this is intended.
    self.scheduler_options["protocol"] = protocol
    self.scheduler_options["scheduler_options"] = scheduler_options
    self.worker_options["env_vars"] = env_vars
    self.options["docker_args"] = docker_args
    self.scheduler_options["docker_args"] = docker_args
    self.worker_options["docker_args"] = docker_args
    self.worker_options["docker_image"] = image
    self.worker_options["worker_class"] = worker_class
    self.worker_options["protocol"] = protocol
    self.worker_options["worker_options"] = worker_options
    # First eight characters of a random UUID4 string.
    self.uuid = str(uuid.uuid4())[:8]

    super().__init__(**kwargs, security=self.security)
import pytest

import dask
from distributed import Client, Scheduler, Worker
from distributed.comm import connect, listen, ws
from distributed.comm.core import FatalCommClosedError
from distributed.comm.registry import backends, get_backend
from distributed.security import Security
from distributed.utils_test import (  # noqa: F401
    cleanup,
    gen_cluster,
    get_client_ssl_context,
    get_server_ssl_context,
    inc,
)

from .test_comms import check_tls_extra

# Temporary credentials created once, at import time, as a module-level name.
security = Security.temporary()


def test_registered():
    """The "ws" backend is registered and is a ``ws.WSBackend`` instance."""
    assert "ws" in backends
    backend = get_backend("ws")
    assert isinstance(backend, ws.WSBackend)


@pytest.mark.asyncio
async def test_listen_connect(cleanup):
    # NOTE(review): only the echo handler is visible in this view; the rest
    # of the test body appears to be cut off -- confirm against the full file.
    async def handle_comm(comm):
        # Echo every message read from the comm straight back, forever.
        while True:
            msg = await comm.read()
            await comm.write(msg)
def test_repr_temp_keys():
    """The repr of temporary credentials marks them as in-memory."""
    xfail_ssl_issue5601()
    pytest.importorskip("cryptography")
    credentials = Security.temporary()
    assert "Temporary (In-memory)" in repr(credentials)
def __init__(
    self,
    name=None,
    n_workers=None,
    threads_per_worker=None,
    processes=None,
    loop=None,
    start=None,
    host=None,
    ip=None,
    scheduler_port=0,
    silence_logs=logging.WARN,
    dashboard_address=":8787",
    worker_dashboard_address=None,
    diagnostics_port=None,
    services=None,
    worker_services=None,
    service_kwargs=None,
    asynchronous=False,
    security=None,
    protocol=None,
    blocked_handlers=None,
    interface=None,
    worker_class=None,
    scheduler_kwargs=None,
    scheduler_sync_interval=1,
    **worker_kwargs,
):
    """Resolve defaults, then build the scheduler and worker specs.

    Handles the deprecated/alias keywords (``ip``, ``diagnostics_port``,
    ``threads_per_worker=0``), picks the worker class, security object,
    protocol, host and worker/thread counts, and finally hands the
    scheduler spec and one worker spec per worker to the parent
    constructor.

    Raises
    ------
    TypeError
        If ``security`` is not ``None``, ``True`` or a ``Security`` instance.
    """
    # ``ip`` is an alias for ``host``; when given it takes precedence.
    if ip is not None:
        # In the future we should warn users about this move
        # warnings.warn("The ip keyword has been moved to host")
        host = ip

    # Deprecated alias: its value replaces ``dashboard_address``.
    if diagnostics_port is not None:
        warnings.warn(
            "diagnostics_port has been deprecated. "
            "Please use `dashboard_address=` instead"
        )
        dashboard_address = diagnostics_port

    if threads_per_worker == 0:
        warnings.warn(
            "Setting `threads_per_worker` to 0 has been deprecated. "
            "Please set to None or to a specific int."
        )
        threads_per_worker = None

    if "dashboard" in worker_kwargs:
        warnings.warn(
            "Setting `dashboard` is discouraged. "
            "Please set `dashboard_address` to affect the scheduler (more common) "
            "and `worker_dashboard_address` for the worker (less common)."
        )

    # With no explicit ``processes``, use processes when no worker class was
    # given or when the given class is a Nanny subclass.
    if processes is None:
        processes = worker_class is None or issubclass(worker_class, Nanny)
    if worker_class is None:
        worker_class = Nanny if processes else Worker

    self.status = None
    self.processes = processes

    if security is None:
        # ``None`` loads the default configuration; any other value that is
        # not ``True`` or a Security instance hits the TypeError below.
        security = Security()
    elif security is True:
        # True indicates self-signed temporary credentials should be used
        security = Security.temporary()
    elif not isinstance(security, Security):
        raise TypeError("security must be a Security object")

    # Protocol precedence when not given explicitly: scheme embedded in
    # ``host``, then TLS if encryption is required, then in-process when
    # neither processes nor a scheduler port are used, else TCP.
    if protocol is None:
        if host and "://" in host:
            protocol = host.split("://")[0]
        elif security and security.require_encryption:
            protocol = "tls://"
        elif not self.processes and not scheduler_port:
            protocol = "inproc://"
        else:
            protocol = "tcp://"
    # Normalise to the "<scheme>://" form.
    if not protocol.endswith("://"):
        protocol = protocol + "://"

    # Default to loopback unless running in-process or an interface was given.
    if host is None and not protocol.startswith("inproc") and not interface:
        host = "127.0.0.1"

    services = services or {}
    worker_services = worker_services or {}
    # The checks below run in order; each fills in only what is still unset.
    if n_workers is None and threads_per_worker is None:
        if processes:
            n_workers, threads_per_worker = nprocesses_nthreads()
        else:
            n_workers = 1
            threads_per_worker = CPU_COUNT
    if n_workers is None and threads_per_worker is not None:
        n_workers = max(1, CPU_COUNT // threads_per_worker) if processes else 1
    if n_workers and threads_per_worker is None:
        # Overcommit threads per worker, rather than undercommit
        threads_per_worker = max(1, int(math.ceil(CPU_COUNT / n_workers)))
    if n_workers and "memory_limit" not in worker_kwargs:
        worker_kwargs["memory_limit"] = parse_memory_limit("auto", 1, n_workers)

    # The resolved values overwrite same-named keys passed in ``worker_kwargs``.
    worker_kwargs.update(
        {
            "host": host,
            "nthreads": threads_per_worker,
            "services": worker_services,
            "dashboard_address": worker_dashboard_address,
            "dashboard": worker_dashboard_address is not None,
            "interface": interface,
            "protocol": protocol,
            "security": security,
            "silence_logs": silence_logs,
        }
    )

    # Entries in ``scheduler_kwargs`` override the computed scheduler options.
    scheduler = {
        "cls": Scheduler,
        "options": toolz.merge(
            dict(
                host=host,
                services=services,
                service_kwargs=service_kwargs,
                security=security,
                port=scheduler_port,
                interface=interface,
                protocol=protocol,
                dashboard=dashboard_address is not None,
                dashboard_address=dashboard_address,
                blocked_handlers=blocked_handlers,
            ),
            scheduler_kwargs or {},
        ),
    }

    worker = {"cls": worker_class, "options": worker_kwargs}
    # Every worker index maps to the same spec object.
    workers = {i: worker for i in range(n_workers)}

    super().__init__(
        name=name,
        scheduler=scheduler,
        workers=workers,
        worker=worker,
        loop=loop,
        asynchronous=asynchronous,
        silence_logs=silence_logs,
        security=security,
        scheduler_sync_interval=scheduler_sync_interval,
    )
from distributed import Client, Scheduler, Worker
from distributed.comm import connect, listen, ws
from distributed.comm.registry import backends, get_backend
from distributed.security import Security
from distributed.utils_test import (  # noqa: F401
    cleanup,
    gen_cluster,
    get_client_ssl_context,
    get_server_ssl_context,
    inc,
)

from .test_comms import check_tls_extra

# Temporary credentials created once, at import time, as a module-level name.
security = Security.temporary()


def test_registered():
    """The "ws" backend is registered and is a ``ws.WSBackend`` instance."""
    assert "ws" in backends
    backend = get_backend("ws")
    assert isinstance(backend, ws.WSBackend)


@pytest.mark.asyncio
async def test_listen_connect(cleanup):
    # NOTE(review): only the echo handler is visible in this view; the rest
    # of the test body appears to be cut off -- confirm against the full file.
    async def handle_comm(comm):
        # Echo every message read from the comm straight back, forever.
        while True:
            msg = await comm.read()
            await comm.write(msg)