def test_attribute_error():
    """Unknown attribute names on ``Security`` raise ``AttributeError``."""
    security = Security()

    # A real setting is reachable as an attribute.
    assert hasattr(security, 'tls_ca_file')

    # Reading a name that is not a setting must fail ...
    with pytest.raises(AttributeError):
        getattr(security, 'tls_foobar')

    # ... and so must assigning to it.
    with pytest.raises(AttributeError):
        setattr(security, 'tls_foobar', "")
def test_require_encryption():
    """
    Functional test for "require_encryption" setting.

    Two ``Security`` objects are built from the same TLS configuration, one
    without and one with ``require-encryption``.  Connections are then
    attempted over several address schemes to check which combinations
    succeed and which raise ``RuntimeError``.
    """
    @gen.coroutine
    def handle_comm(comm):
        # Nothing is exchanged in this test; drop the connection right away.
        comm.abort()

    config = {
        'tls': {
            'ca-file': ca_file,
            'scheduler': {'key': key1, 'cert': cert1},
            'worker': {'cert': keycert1},
        },
    }
    with new_config(config):
        lenient = Security()
    config['require-encryption'] = True
    with new_config(config):
        strict = Security()

    # With these schemes a strict client connects to a listener set up with
    # either security object.
    for listen_addr in ['inproc://', 'tls://']:
        for listener_security in (lenient, strict):
            with listen(
                listen_addr, handle_comm,
                connection_args=listener_security.get_listen_args('scheduler'),
            ) as listener:
                comm = yield connect(
                    listener.contact_address,
                    connection_args=strict.get_connection_args('worker'),
                )
                comm.abort()

    # Helper that also checks the error message.  It is defined here but the
    # checks below use plain ``pytest.raises`` and never call it.
    @contextmanager
    def check_encryption_error():
        with pytest.raises(RuntimeError) as excinfo:
            yield
        assert "encryption required" in str(excinfo.value)

    for listen_addr in ['tcp://']:
        with listen(
            listen_addr, handle_comm,
            connection_args=lenient.get_listen_args('scheduler'),
        ) as listener:
            # Lenient client over plain TCP: accepted.
            comm = yield connect(
                listener.contact_address,
                connection_args=lenient.get_connection_args('worker'),
            )
            comm.abort()

            # Strict client over plain TCP: refused.
            with pytest.raises(RuntimeError):
                yield connect(
                    listener.contact_address,
                    connection_args=strict.get_connection_args('worker'),
                )

        # Strict listener on plain TCP: refused at creation.
        with pytest.raises(RuntimeError):
            listen(
                listen_addr, handle_comm,
                connection_args=strict.get_listen_args('scheduler'),
            )
def test_listen_args():
    """Check the SSL contexts returned by ``Security.get_listen_args``."""
    # ``SSLContext.get_ciphers`` is only consulted on Python 3.6+.
    can_list_ciphers = sys.version_info >= (3, 6)

    def check_context(ctx, expect_many_ciphers):
        assert ctx.verify_mode == ssl.CERT_REQUIRED
        assert ctx.check_hostname is False
        if expect_many_ciphers and can_list_ciphers:
            assert len(ctx.get_ciphers()) > 2  # Most likely

    config = {
        'tls': {
            'ca-file': ca_file,
            'scheduler': {'key': key1, 'cert': cert1},
            'worker': {'cert': keycert1},
        },
    }
    with new_config(config):
        security = Security()

    listen_args = security.get_listen_args('scheduler')
    assert not listen_args['require_encryption']
    check_context(listen_args['ssl_context'], expect_many_ciphers=True)

    listen_args = security.get_listen_args('worker')
    check_context(listen_args['ssl_context'], expect_many_ciphers=True)

    # No cert defined => no TLS
    listen_args = security.get_listen_args('client')
    assert listen_args.get('ssl_context') is None

    # With more settings
    config['tls']['ciphers'] = FORCED_CIPHER
    config['require-encryption'] = True
    with new_config(config):
        security = Security()

    listen_args = security.get_listen_args('scheduler')
    assert listen_args['require_encryption']
    forced_ctx = listen_args['ssl_context']
    check_context(forced_ctx, expect_many_ciphers=False)
    if can_list_ciphers:
        # Group the enabled ciphers by the protocol they belong to.
        by_protocol = {}
        for entry in forced_ctx.get_ciphers():
            by_protocol.setdefault(entry['protocol'], []).append(entry)

        assert len(by_protocol.get('TLSv1.2', [])) == 1
        tls_13 = by_protocol.get('TLSv1.3', [])
        if tls_13:
            assert len(tls_13) == 3
def test_tls_listen_connect():
    """
    Functional test for TLS connection args.

    A TLS listener is opened with the scheduler's listen args and contacted
    with worker args, client args (no SSL context) and worker args that
    carry a forced cipher.
    """
    @gen.coroutine
    def handle_comm(comm):
        assert comm.peer_address.startswith('tls://')
        yield comm.write('hello')
        yield comm.close()

    config = {
        'tls': {
            'ca-file': ca_file,
            'scheduler': {'key': key1, 'cert': cert1},
            'worker': {'cert': keycert1},
        },
    }
    with new_config(config):
        security = Security()
    config['tls']['ciphers'] = FORCED_CIPHER
    with new_config(config):
        forced_cipher_security = Security()

    with listen(
        'tls://', handle_comm,
        connection_args=security.get_listen_args('scheduler'),
    ) as listener:
        address = listener.contact_address

        comm = yield connect(
            address, connection_args=security.get_connection_args('worker'))
        greeting = yield comm.read()
        assert greeting == 'hello'
        comm.abort()

        # No SSL context for client
        with pytest.raises(TypeError):
            yield connect(
                address,
                connection_args=security.get_connection_args('client'))

        # Check forced cipher
        comm = yield connect(
            address,
            connection_args=forced_cipher_security.get_connection_args('worker'))
        negotiated, _, _ = comm.extra_info['cipher']
        acceptable = [FORCED_CIPHER] + TLS_13_CIPHERS
        assert negotiated in acceptable
        comm.abort()
def test_tls_config_for_role():
    """``get_tls_config_for_role`` merges shared and per-role TLS settings."""
    config = {
        'tls': {
            'ca-file': 'ca.pem',
            'scheduler': {'key': 'skey.pem', 'cert': 'scert.pem'},
            'worker': {'cert': 'wcert.pem'},
            'ciphers': FORCED_CIPHER,
        },
    }
    with new_config(config):
        security = Security()

    # (role, expected key, expected cert); CA file and ciphers are shared.
    expectations = [
        ('scheduler', 'skey.pem', 'scert.pem'),
        ('worker', None, 'wcert.pem'),
        ('client', None, None),
    ]
    for role, expected_key, expected_cert in expectations:
        assert security.get_tls_config_for_role(role) == {
            'ca_file': 'ca.pem',
            'key': expected_key,
            'cert': expected_cert,
            'ciphers': FORCED_CIPHER,
        }

    # A role name outside the three above is rejected.
    with pytest.raises(ValueError):
        security.get_tls_config_for_role('supervisor')
def test_kwargs():
    """Keyword arguments override config values unless they are ``None``."""
    config = {
        'tls': {
            'ca-file': 'ca.pem',
            'scheduler': {'key': 'skey.pem', 'cert': 'scert.pem'},
        },
    }
    overrides = dict(
        tls_scheduler_cert='newcert.pem',
        require_encryption=True,
        tls_ca_file=None,
    )
    with new_config(config):
        security = Security(**overrides)

    assert security.require_encryption is True
    # None value didn't override default
    assert security.tls_ca_file == 'ca.pem'
    assert security.tls_ciphers is None
    for unset in ('tls_client_key', 'tls_client_cert'):
        assert getattr(security, unset) is None
    assert security.tls_scheduler_key == 'skey.pem'
    assert security.tls_scheduler_cert == 'newcert.pem'
    for unset in ('tls_worker_key', 'tls_worker_cert'):
        assert getattr(security, unset) is None
def test_kwargs():
    """Keyword arguments override config values unless they are ``None``."""
    scheduler_tls = {"key": "skey.pem", "cert": "scert.pem"}
    config = {"tls": {"ca-file": "ca.pem", "scheduler": scheduler_tls}}

    with new_config(config):
        security = Security(
            tls_scheduler_cert="newcert.pem",
            require_encryption=True,
            tls_ca_file=None,
        )

    assert security.require_encryption is True

    # None value didn't override default
    assert security.tls_ca_file == "ca.pem"

    # Settings that neither the config nor the keywords provide stay unset.
    assert security.tls_ciphers is None
    assert security.tls_client_key is None
    assert security.tls_client_cert is None

    # The key comes from the config, the cert from the keyword argument.
    assert security.tls_scheduler_key == "skey.pem"
    assert security.tls_scheduler_cert == "newcert.pem"

    assert security.tls_worker_key is None
    assert security.tls_worker_cert is None
def __init__(
    self,
    n_workers: int = 0,
    worker_class: str = "dask.distributed.Nanny",
    worker_options: dict = None,
    scheduler_options: dict = None,
    docker_image="daskdev/dask:latest",
    docker_args: str = "",
    env_vars: dict = None,
    security: bool = True,
    protocol: str = None,
    **kwargs,
):
    """Set up security, protocol and the option dicts shared by all VMs.

    ``self.scheduler_class``, ``self.worker_class``, ``self.options``,
    ``self.scheduler_options`` and ``self.worker_options`` are read or
    mutated here but never created, so they must already exist on the
    instance before this runs.

    Parameters
    ----------
    n_workers
        Stored as ``self._n_workers``.
    worker_class
        Stored under ``worker_options["worker_class"]``.
    worker_options, scheduler_options
        Stored under the nested ``"worker_options"`` /
        ``"scheduler_options"`` keys.  ``None`` (the default) means a
        fresh empty dict.
    docker_image
        Used unless ``self.scheduler_options`` already carries a truthy
        ``"docker_image"``.
    docker_args
        Copied into the cluster, scheduler and worker options.
    env_vars
        Copied into the scheduler and worker options.  ``None`` (the
        default) means a fresh empty dict.
    security
        ``True`` requests temporary credentials via
        ``Security.temporary()``; any falsy value disables security;
        otherwise a ``Security`` instance is required.
    protocol
        When ``None``, ``"tls"`` is chosen if the resolved security
        requires encryption, else ``"tcp"``.
    **kwargs
        Forwarded to the parent initialiser.

    Raises
    ------
    RuntimeError
        If ``scheduler_class`` or ``worker_class`` is ``None``.
    TypeError
        If ``security`` is truthy but neither ``True`` nor a ``Security``.
    """
    if self.scheduler_class is None or self.worker_class is None:
        raise RuntimeError(
            "VMCluster is not intended to be used directly. See docstring for more info."
        )

    # These dicts are stored on the instance below.  Build a new one per
    # call so instances never share (and accidentally mutate) a single
    # default object.
    worker_options = {} if worker_options is None else worker_options
    scheduler_options = {} if scheduler_options is None else scheduler_options
    env_vars = {} if env_vars is None else env_vars

    self._n_workers = n_workers

    if not security:
        self.security = None
    elif security is True:
        # True indicates self-signed temporary credentials should be used
        self.security = Security.temporary()
    elif not isinstance(security, Security):
        raise TypeError("security must be a Security object")
    else:
        self.security = security

    if protocol is None:
        if self.security and self.security.require_encryption:
            self.protocol = "tls"
        else:
            self.protocol = "tcp"
    else:
        self.protocol = protocol

    if self.security and self.security.require_encryption:
        # Publish the credentials through the global dask configuration.
        dask.config.set({
            "distributed.comm.default-scheme": self.protocol,
            "distributed.comm.require-encryption": True,
            "distributed.comm.tls.ca-file": self.security.tls_ca_file,
            "distributed.comm.tls.scheduler.key": self.security.tls_scheduler_key,
            "distributed.comm.tls.scheduler.cert": self.security.tls_scheduler_cert,
            "distributed.comm.tls.worker.key": self.security.tls_worker_key,
            "distributed.comm.tls.worker.cert": self.security.tls_worker_cert,
            "distributed.comm.tls.client.key": self.security.tls_client_key,
            "distributed.comm.tls.client.cert": self.security.tls_client_cert,
        })

    # An image already present in the scheduler options wins over the
    # ``docker_image`` argument.
    image = self.scheduler_options.get("docker_image", False) or docker_image
    self.options["docker_image"] = image
    self.scheduler_options["docker_image"] = image
    self.scheduler_options["env_vars"] = env_vars
    # NOTE(review): the raw ``protocol`` argument (possibly None) is
    # forwarded here and to the workers, not the resolved
    # ``self.protocol`` -- confirm this is intended.
    self.scheduler_options["protocol"] = protocol
    self.scheduler_options["scheduler_options"] = scheduler_options
    self.worker_options["env_vars"] = env_vars
    self.options["docker_args"] = docker_args
    self.scheduler_options["docker_args"] = docker_args
    self.worker_options["docker_args"] = docker_args
    self.worker_options["docker_image"] = image
    self.worker_options["worker_class"] = worker_class
    self.worker_options["protocol"] = protocol
    self.worker_options["worker_options"] = worker_options
    self.uuid = str(uuid.uuid4())[:8]

    super().__init__(**kwargs, security=self.security)
def test_repr_temp_keys():
    """The repr of temporary credentials labels them as in-memory."""
    text = repr(Security.temporary())
    assert "Temporary (In-memory)" in text
# Aliases for the py2/py3-compatible URL parsing helpers provided by ``six``.
urllib_parse = six.moves.urllib_parse
urlparse = urllib_parse.urlparse

logger = logging.getLogger('distributed.preloading')

from dask_signal import *
from dask_usage import USAGE_INFO

log_dir = os.path.expanduser('~/dask/logs/')   # ATTENTION: this is the remote location !

# TLS mode notes: 1. paths must be absolute -- '~' and wildcards have to be
# expanded first; 2. the Client address must use the TLS scheme,
# i.e. "tls://host:port".
TLS_CA_FILE=os.path.expanduser('~/.dask/ca.crt')
TLS_CA_CERT=os.path.expanduser('~/.dask/ca.crt')
TLS_CA_KEY=os.path.expanduser('~/.dask/ca.key')

# One Security object per role.  All three reuse the same ca.crt / ca.key
# pair as their certificate and key; only the client one sets
# require_encryption.
SECURITY = SEC = Security(tls_ca_file=TLS_CA_FILE, tls_scheduler_cert=TLS_CA_CERT, tls_scheduler_key=TLS_CA_KEY)
SECURITY_WORKER = Security(tls_ca_file=TLS_CA_FILE, tls_worker_cert=TLS_CA_CERT, tls_worker_key=TLS_CA_KEY)
SECURITY_CLIENT = Security(tls_ca_file=TLS_CA_FILE, tls_client_cert=TLS_CA_CERT, tls_client_key=TLS_CA_KEY, require_encryption=True)

SCHEDULER_PORT = 8786
# TLS address of the first entry in MASTERS.  MASTERS is not defined in this
# span -- presumably it comes from one of the imports above; verify.
GLOBAL_CLUSTER = list(map(lambda m:'tls://%s:%s'%(m, SCHEDULER_PORT), MASTERS))[0]

# SSH connection settings.
SSH_USER = '******'
SSH_MASTER_ip = 'gpu01.ops.zzyc.abael.com'
SSH_PKEY = os.path.expanduser('~/.ssh/id_rsa')
SSH_PUB = os.path.expanduser('~/.ssh/id_rsa.pub')
SSH_WORKER_python = '/usr/bin/python3.6'
SSH_PORT = 22
SSH_MASTER_port = 11111
SSH_NANNY_port = 22222
def main(
    scheduler,
    host,
    nthreads,
    name,
    memory_limit,
    device_memory_limit,
    rmm_pool_size,
    rmm_managed_memory,
    pid_file,
    resources,
    dashboard,
    dashboard_address,
    local_directory,
    scheduler_file,
    interface,
    death_timeout,
    preload,
    dashboard_prefix,
    tls_ca_file,
    tls_cert,
    tls_key,
    enable_tcp_over_ucx,
    enable_infiniband,
    enable_nvlink,
    enable_rdmacm,
    net_devices,
    **kwargs,
):
    """Build a ``CUDAWorker`` from the given options and run it to completion.

    A ``Security`` object is only created when the CA file, certificate and
    key are all provided; otherwise the worker receives ``security=None``.
    ``KeyboardInterrupt`` and ``TimeoutError`` raised while running are
    swallowed, and "End worker" is logged in every case.
    """
    security = None
    if tls_ca_file and tls_cert and tls_key:
        tls_settings = {
            "tls_ca_file": tls_ca_file,
            "tls_worker_cert": tls_cert,
            "tls_worker_key": tls_key,
        }
        security = Security(**tls_settings)

    # CUDAWorker takes these positionally, so the order below is significant.
    worker_args = (
        scheduler,
        host,
        nthreads,
        name,
        memory_limit,
        device_memory_limit,
        rmm_pool_size,
        rmm_managed_memory,
        pid_file,
        resources,
        dashboard,
        dashboard_address,
        local_directory,
        scheduler_file,
        interface,
        death_timeout,
        preload,
        dashboard_prefix,
        security,
        enable_tcp_over_ucx,
        enable_infiniband,
        enable_nvlink,
        enable_rdmacm,
        net_devices,
    )
    worker = CUDAWorker(*worker_args, **kwargs)

    async def _close_on_signal(signum):
        logger.info("Exiting on signal %d", signum)
        await worker.close()

    async def _serve():
        # Start the worker, then wait until it has finished.
        await worker
        await worker.finished()

    loop = IOLoop.current()
    install_signal_handlers(loop, cleanup=_close_on_signal)

    try:
        loop.run_sync(_serve)
    except (KeyboardInterrupt, TimeoutError):
        pass
    finally:
        logger.info("End worker")
def main(host, port, bokeh_port, show, _bokeh, bokeh_whitelist, bokeh_prefix,
         use_xheaders, pid_file, scheduler_file, interface, local_directory,
         preload, preload_argv, tls_ca_file, tls_cert, tls_key):
    """Start a ``Scheduler`` and run the IO loop until it stops.

    Steps, in order: build the ``Security`` object, default the host to
    ``'tls://'`` when any TLS option is given, write the pid file, prepare
    the local directory, raise the open-file soft limit on Linux, resolve
    the listen address, start the scheduler, run preload modules, install
    signal handlers and run the loop.

    ``show``, ``bokeh_whitelist`` and ``use_xheaders`` are accepted but not
    referenced in this function body.

    Raises ``ValueError`` when both ``interface`` and ``host`` are set (note
    that ``host`` may have been defaulted to ``'tls://'`` just before).
    """
    enable_proctitle_on_current()
    enable_proctitle_on_children()

    sec = Security(
        tls_ca_file=tls_ca_file,
        tls_scheduler_cert=tls_cert,
        tls_scheduler_key=tls_key,
    )

    # Any TLS option without an explicit host selects the TLS scheme.
    if not host and (tls_ca_file or tls_cert or tls_key):
        host = 'tls://'

    if pid_file:
        with open(pid_file, 'w') as f:
            f.write(str(os.getpid()))

        # Remove the pid file again when the interpreter exits.
        def del_pid_file():
            if os.path.exists(pid_file):
                os.remove(pid_file)

        atexit.register(del_pid_file)

    # Track whether this function created the directory, so that only a
    # directory it created is removed at shutdown.
    local_directory_created = False
    if local_directory:
        if not os.path.exists(local_directory):
            os.mkdir(local_directory)
            local_directory_created = True
    else:
        local_directory = tempfile.mkdtemp(prefix='scheduler-')
        local_directory_created = True
    if local_directory not in sys.path:
        sys.path.insert(0, local_directory)

    if sys.platform.startswith('linux'):
        import resource  # module fails importing on Windows
        soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
        # Raise the soft limit to at least half the hard limit.
        limit = max(soft, hard // 2)
        resource.setrlimit(resource.RLIMIT_NOFILE, (limit, hard))

    if interface:
        if host:
            raise ValueError("Can not specify both interface and host")
        else:
            host = get_ip_interface(interface)

    # 8786 is passed as the third argument of uri_from_host_port.
    addr = uri_from_host_port(host, port, 8786)

    loop = IOLoop.current()
    logger.info('-' * 47)

    services = {}
    if _bokeh:
        # The dashboard is optional: skip it when bokeh cannot be imported.
        with ignoring(ImportError):
            from distributed.bokeh.scheduler import BokehScheduler
            services[('bokeh', bokeh_port)] = (BokehScheduler, {
                'prefix': bokeh_prefix
            })

    scheduler = Scheduler(loop=loop,
                          services=services,
                          scheduler_file=scheduler_file,
                          security=sec)
    scheduler.start(addr)
    # Fall back to the configured preload settings when none were passed.
    if not preload:
        preload = dask.config.get('distributed.scheduler.preload')
    if not preload_argv:
        preload_argv = dask.config.get('distributed.scheduler.preload-argv')
    preload_modules(preload,
                    parameter=scheduler,
                    file_dir=local_directory,
                    argv=preload_argv)

    logger.info('Local Directory: %26s', local_directory)
    logger.info('-' * 47)

    install_signal_handlers(loop)

    try:
        loop.start()
        loop.close()
    finally:
        scheduler.stop()
        if local_directory_created:
            shutil.rmtree(local_directory)

        logger.info("End scheduler at %r", addr)
def test_listen_args():
    """Check the SSL contexts returned by ``Security.get_listen_args``."""
    # ``SSLContext.get_ciphers`` is only consulted on Python 3.6+.
    cipher_listing = sys.version_info >= (3, 6)

    def assert_verifying_context(ctx):
        assert ctx.verify_mode == ssl.CERT_REQUIRED
        assert ctx.check_hostname is False

    def assert_default_ciphers(ctx):
        if cipher_listing:
            assert len(ctx.get_ciphers()) > 2  # Most likely

    settings = {
        'tls': {
            'ca-file': ca_file,
            'scheduler': {'key': key1, 'cert': cert1},
            'worker': {'cert': keycert1},
        },
    }
    with new_config(settings):
        security = Security()

    scheduler_args = security.get_listen_args('scheduler')
    assert not scheduler_args['require_encryption']
    assert_verifying_context(scheduler_args['ssl_context'])
    assert_default_ciphers(scheduler_args['ssl_context'])

    worker_args = security.get_listen_args('worker')
    assert_verifying_context(worker_args['ssl_context'])
    assert_default_ciphers(worker_args['ssl_context'])

    # No cert defined => no TLS
    client_args = security.get_listen_args('client')
    assert client_args.get('ssl_context') is None

    # With more settings
    settings['tls']['ciphers'] = FORCED_CIPHER
    settings['require-encryption'] = True
    with new_config(settings):
        security = Security()

    scheduler_args = security.get_listen_args('scheduler')
    assert scheduler_args['require_encryption']
    restricted = scheduler_args['ssl_context']
    assert_verifying_context(restricted)
    if cipher_listing:
        enabled = restricted.get_ciphers()

        def ciphers_for(protocol):
            return [item for item in enabled if item['protocol'] == protocol]

        assert len(ciphers_for('TLSv1.2')) == 1
        tls_13 = ciphers_for('TLSv1.3')
        if tls_13:
            assert len(tls_13) == 3
def main(host, port, http_port, bokeh_port, bokeh_internal_port, show, _bokeh,
         bokeh_whitelist, bokeh_prefix, use_xheaders, pid_file, scheduler_file,
         interface, local_directory, preload, prefix, tls_ca_file, tls_cert,
         tls_key):
    """Start a ``Scheduler`` with HTTP (and optionally bokeh) services.

    Exits with status 1 when the removed ``bokeh_internal_port`` or the
    moved ``prefix`` option is supplied.  Otherwise: builds the ``Security``
    object, writes the pid file, prepares the local directory, raises the
    open-file soft limit on Linux, resolves the listen address, starts the
    scheduler, runs preload modules and runs the IO loop.

    ``show``, ``bokeh_whitelist`` and ``use_xheaders`` are accepted but not
    referenced in this function body.

    Raises ``ValueError`` when both ``interface`` and ``host`` are set.
    """
    if bokeh_internal_port:
        print("The --bokeh-internal-port keyword has been removed.\n"
              "The internal bokeh server is now the default bokeh server.\n"
              "Use --bokeh-port %d instead" % bokeh_internal_port)
        sys.exit(1)

    if prefix:
        print("The --prefix keyword has moved to --bokeh-prefix")
        sys.exit(1)

    sec = Security(
        tls_ca_file=tls_ca_file,
        tls_scheduler_cert=tls_cert,
        tls_scheduler_key=tls_key,
    )

    if pid_file:
        with open(pid_file, 'w') as f:
            f.write(str(os.getpid()))

        # Remove the pid file again when the interpreter exits.
        def del_pid_file():
            if os.path.exists(pid_file):
                os.remove(pid_file)

        atexit.register(del_pid_file)

    # Track whether this function created the directory, so that only a
    # directory it created is removed at shutdown.
    local_directory_created = False
    if local_directory:
        if not os.path.exists(local_directory):
            os.mkdir(local_directory)
            local_directory_created = True
    else:
        local_directory = tempfile.mkdtemp(prefix='scheduler-')
        local_directory_created = True
    if local_directory not in sys.path:
        sys.path.insert(0, local_directory)

    if sys.platform.startswith('linux'):
        import resource  # module fails importing on Windows
        soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
        # Raise the soft limit to at least half the hard limit.
        limit = max(soft, hard // 2)
        resource.setrlimit(resource.RLIMIT_NOFILE, (limit, hard))

    if interface:
        if host:
            raise ValueError("Can not specify both interface and host")
        else:
            host = get_ip_interface(interface)

    # 8786 is passed as the third argument of uri_from_host_port.
    addr = uri_from_host_port(host, port, 8786)

    loop = IOLoop.current()
    logger.info('-' * 47)

    # The HTTP service is always registered; bokeh only when requested and
    # importable.
    services = {('http', http_port): HTTPScheduler}
    if _bokeh:
        with ignoring(ImportError):
            from distributed.bokeh.scheduler import BokehScheduler
            services[('bokeh', bokeh_port)] = partial(BokehScheduler,
                                                      prefix=bokeh_prefix)
    scheduler = Scheduler(loop=loop,
                          services=services,
                          scheduler_file=scheduler_file,
                          security=sec)
    scheduler.start(addr)
    preload_modules(preload, parameter=scheduler, file_dir=local_directory)

    logger.info('Local Directory: %26s', local_directory)
    logger.info('-' * 47)
    try:
        loop.start()
        loop.close()
    finally:
        scheduler.stop()
        if local_directory_created:
            shutil.rmtree(local_directory)

        logger.info("End scheduler at %r", addr)
def test_constructor_errors():
    """An unknown keyword raises ``TypeError`` that names the keyword."""
    bad_kwargs = {"unknown_keyword": "bar"}
    with pytest.raises(TypeError) as excinfo:
        Security(**bad_kwargs)

    message = str(excinfo.value)
    assert "unknown_keyword" in message
def test_require_encryption():
    """
    Functional test for "require_encryption" setting.

    Two ``Security`` objects are built from the same TLS configuration, one
    without and one with ``require-encryption``.  Connections are then
    attempted over several address schemes to check which combinations
    succeed and which raise ``RuntimeError``.
    """
    @gen.coroutine
    def handle_comm(comm):
        # Nothing is exchanged in this test; drop the connection right away.
        comm.abort()

    config = {
        "tls": {
            "ca-file": ca_file,
            "scheduler": {"key": key1, "cert": cert1},
            "worker": {"cert": keycert1},
        }
    }
    with new_config(config):
        lenient = Security()
    config["require-encryption"] = True
    with new_config(config):
        strict = Security()

    # With these schemes a strict client connects to a listener set up with
    # either security object.
    for listen_addr in ["inproc://", "tls://"]:
        for listener_security in (lenient, strict):
            with listen(
                listen_addr,
                handle_comm,
                connection_args=listener_security.get_listen_args("scheduler"),
            ) as listener:
                comm = yield connect(
                    listener.contact_address,
                    connection_args=strict.get_connection_args("worker"),
                )
                comm.abort()

    # Helper that also checks the error message.  It is defined here but the
    # checks below use plain ``pytest.raises`` and never call it.
    @contextmanager
    def check_encryption_error():
        with pytest.raises(RuntimeError) as excinfo:
            yield
        assert "encryption required" in str(excinfo.value)

    for listen_addr in ["tcp://"]:
        with listen(
            listen_addr,
            handle_comm,
            connection_args=lenient.get_listen_args("scheduler"),
        ) as listener:
            # Lenient client over plain TCP: accepted.
            comm = yield connect(
                listener.contact_address,
                connection_args=lenient.get_connection_args("worker"),
            )
            comm.abort()

            # Strict client over plain TCP: refused.
            with pytest.raises(RuntimeError):
                yield connect(
                    listener.contact_address,
                    connection_args=strict.get_connection_args("worker"),
                )

        # Strict listener on plain TCP: refused at creation.
        with pytest.raises(RuntimeError):
            listen(
                listen_addr,
                handle_comm,
                connection_args=strict.get_listen_args("scheduler"),
            )
def main(
    host,
    port,
    bokeh_port,
    show,
    dashboard,
    bokeh,
    dashboard_prefix,
    use_xheaders,
    pid_file,
    local_directory,
    tls_ca_file,
    tls_cert,
    tls_key,
    dashboard_address,
    **kwargs
):
    """Start a ``Scheduler`` and run it until it finishes.

    Steps, in order: relax the GC thresholds, map the renamed ``bokeh*``
    options onto their ``dashboard*`` successors (with a warning), default
    the port, build the ``Security`` object from the TLS options that are
    not ``None``, default the host to ``"tls://"`` when any TLS option is
    given, write the pid file, prepare the local directory, raise the
    open-file soft limit on Linux, then create and run the scheduler.

    ``show`` and ``use_xheaders`` are accepted but not referenced in this
    function body.  Remaining keyword arguments are forwarded to
    ``Scheduler``.
    """
    g0, g1, g2 = gc.get_threshold()  # https://github.com/dask/distributed/issues/1653
    # Triple every GC threshold (see the issue linked above).
    gc.set_threshold(g0 * 3, g1 * 3, g2 * 3)

    enable_proctitle_on_current()
    enable_proctitle_on_children()

    # Renamed options still work but emit a warning.
    if bokeh_port is not None:
        warnings.warn(
            "The --bokeh-port flag has been renamed to --dashboard-address. "
            "Consider adding ``--dashboard-address :%d`` " % bokeh_port
        )
        dashboard_address = bokeh_port
    if bokeh is not None:
        warnings.warn(
            "The --bokeh/--no-bokeh flag has been renamed to --dashboard/--no-dashboard. "
        )
        dashboard = bokeh

    # Apply the default port only when the host string does not already
    # contain a ":<digit>" port suffix.
    if port is None and (not host or not re.search(r":\d", host)):
        port = 8786

    # Only pass TLS options that were actually given.
    sec = Security(
        **{
            k: v
            for k, v in [
                ("tls_ca_file", tls_ca_file),
                ("tls_scheduler_cert", tls_cert),
                ("tls_scheduler_key", tls_key),
            ]
            if v is not None
        }
    )

    # Any TLS option without an explicit host selects the TLS scheme.
    if not host and (tls_ca_file or tls_cert or tls_key):
        host = "tls://"

    if pid_file:
        with open(pid_file, "w") as f:
            f.write(str(os.getpid()))

        # Remove the pid file again when the interpreter exits.
        def del_pid_file():
            if os.path.exists(pid_file):
                os.remove(pid_file)

        atexit.register(del_pid_file)

    # Track whether this function created the directory, so that only a
    # directory it created is removed at shutdown.
    local_directory_created = False
    if local_directory:
        if not os.path.exists(local_directory):
            os.mkdir(local_directory)
            local_directory_created = True
    else:
        local_directory = tempfile.mkdtemp(prefix="scheduler-")
        local_directory_created = True
    if local_directory not in sys.path:
        sys.path.insert(0, local_directory)

    if sys.platform.startswith("linux"):
        import resource  # module fails importing on Windows

        soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
        # Raise the soft limit to at least half the hard limit.
        limit = max(soft, hard // 2)
        resource.setrlimit(resource.RLIMIT_NOFILE, (limit, hard))

    loop = IOLoop.current()
    logger.info("-" * 47)

    scheduler = Scheduler(
        loop=loop,
        security=sec,
        host=host,
        port=port,
        dashboard_address=dashboard_address if dashboard else None,
        service_kwargs={"dashboard": {"prefix": dashboard_prefix}},
        **kwargs,
    )
    logger.info("Local Directory: %26s", local_directory)
    logger.info("-" * 47)

    install_signal_handlers(loop)

    async def run():
        # Start the scheduler, then wait until it has finished.
        await scheduler
        await scheduler.finished()

    try:
        loop.run_sync(run)
    finally:
        scheduler.stop()
        if local_directory_created:
            shutil.rmtree(local_directory)

        logger.info("End scheduler at %r", scheduler.address)
def test_listen_args():
    """Check the SSL contexts returned by ``Security.get_listen_args``."""
    # ``SSLContext.get_ciphers`` is only consulted on Python 3.6+.
    has_get_ciphers = sys.version_info >= (3, 6)

    def verify_context(ctx, default_ciphers=True):
        assert ctx.verify_mode == ssl.CERT_REQUIRED
        assert ctx.check_hostname is False
        if default_ciphers and has_get_ciphers:
            assert len(ctx.get_ciphers()) > 2  # Most likely

    tls = {
        "ca-file": ca_file,
        "scheduler": {"key": key1, "cert": cert1},
        "worker": {"cert": keycert1},
    }
    config = {"tls": tls}
    with new_config(config):
        security = Security()

    args = security.get_listen_args("scheduler")
    assert not args["require_encryption"]
    verify_context(args["ssl_context"])

    args = security.get_listen_args("worker")
    verify_context(args["ssl_context"])

    # No cert defined => no TLS
    args = security.get_listen_args("client")
    assert args.get("ssl_context") is None

    # With more settings
    tls["ciphers"] = FORCED_CIPHER
    config["require-encryption"] = True
    with new_config(config):
        security = Security()

    args = security.get_listen_args("scheduler")
    assert args["require_encryption"]
    context = args["ssl_context"]
    verify_context(context, default_ciphers=False)
    if has_get_ciphers:
        protocols = [entry["protocol"] for entry in context.get_ciphers()]
        assert protocols.count("TLSv1.2") == 1
        tls_13_count = protocols.count("TLSv1.3")
        if tls_13_count:
            assert tls_13_count == 3
def test_repr():
    """The repr lists exactly the settings that were given."""
    expected = "Security(tls_ca_file='ca.pem', tls_scheduler_cert='scert.pem')"
    with new_config({}):
        security = Security(
            tls_ca_file="ca.pem",
            tls_scheduler_cert="scert.pem",
        )
    assert repr(security) == expected
def main(scheduler, host, worker_port, listen_address, contact_address,
         nanny_port, nthreads, nprocs, nanny, name, pid_file, resources,
         dashboard, bokeh, bokeh_port, scheduler_file, dashboard_prefix,
         tls_ca_file, tls_cert, tls_key, dashboard_address, **kwargs):
    """Launch ``nprocs`` workers (or nannies) and run them until they finish.

    Steps, in order: relax the GC thresholds, map the renamed ``bokeh*``
    options onto their ``dashboard*`` successors (with a warning), build the
    ``Security`` object from the TLS options that are not ``None``, reject
    incompatible option combinations (logging an error and calling
    ``exit(1)``), parse the listen/contact addresses, write the pid file,
    parse ``resources``, create the worker objects and run the IO loop.

    Raises ``ValueError`` when no scheduler address, scheduler file or
    configured ``scheduler-address`` is available, and ``TimeoutError`` when
    starting the workers times out.
    """
    g0, g1, g2 = gc.get_threshold(
    )  # https://github.com/dask/distributed/issues/1653
    # Triple every GC threshold (see the issue linked above).
    gc.set_threshold(g0 * 3, g1 * 3, g2 * 3)

    enable_proctitle_on_current()
    enable_proctitle_on_children()

    # Renamed options still work but emit a warning.
    if bokeh_port is not None:
        warnings.warn(
            "The --bokeh-port flag has been renamed to --dashboard-address. "
            "Consider adding ``--dashboard-address :%d`` " % bokeh_port)
        dashboard_address = bokeh_port
    if bokeh is not None:
        warnings.warn(
            "The --bokeh/--no-bokeh flag has been renamed to --dashboard/--no-dashboard. "
        )
        dashboard = bokeh

    # Only pass TLS options that were actually given.
    sec = Security(
        **{
            k: v
            for k, v in [
                ("tls_ca_file", tls_ca_file),
                ("tls_worker_cert", tls_cert),
                ("tls_worker_key", tls_key),
            ] if v is not None
        })

    # --- Option combinations that cannot work together ---
    if nprocs > 1 and worker_port != 0:
        logger.error(
            "Failed to launch worker. You cannot use the --port argument when nprocs > 1."
        )
        exit(1)

    if nprocs > 1 and not nanny:
        logger.error(
            "Failed to launch worker. You cannot use the --no-nanny argument when nprocs > 1."
        )
        exit(1)

    if contact_address and not listen_address:
        logger.error(
            "Failed to launch worker. "
            "Must specify --listen-address when --contact-address is given")
        exit(1)

    if nprocs > 1 and listen_address:
        logger.error("Failed to launch worker. "
                     "You cannot specify --listen-address when nprocs > 1.")
        exit(1)

    if (worker_port or host) and listen_address:
        logger.error(
            "Failed to launch worker. "
            "You cannot specify --listen-address when --worker-port or --host is given."
        )
        exit(1)

    try:
        if listen_address:
            (host, worker_port) = get_address_host_port(listen_address,
                                                        strict=True)

        if contact_address:
            # we only need this to verify it is getting parsed
            (_, _) = get_address_host_port(contact_address, strict=True)
        else:
            # if contact address is not present we use the listen_address for contact
            contact_address = listen_address
    except ValueError as e:
        logger.error("Failed to launch worker. " + str(e))
        exit(1)

    # The port given to each created object: the nanny port when nannies are
    # used, else the worker port.
    if nanny:
        port = nanny_port
    else:
        port = worker_port

    # Default: split the available CPUs evenly between the processes.
    if not nthreads:
        nthreads = CPU_COUNT // nprocs

    if pid_file:
        with open(pid_file, "w") as f:
            f.write(str(os.getpid()))

        # Remove the pid file again when the interpreter exits.
        def del_pid_file():
            if os.path.exists(pid_file):
                os.remove(pid_file)

        atexit.register(del_pid_file)

    # Parse "NAME=VALUE" pairs separated by commas and/or whitespace into a
    # dict of floats.
    if resources:
        resources = resources.replace(",", " ").split()
        resources = dict(pair.split("=") for pair in resources)
        resources = valmap(float, resources)
    else:
        resources = None

    loop = IOLoop.current()

    # ``t`` is the class instantiated once per process below.
    if nanny:
        kwargs.update({
            "worker_port": worker_port,
            "listen_address": listen_address
        })
        t = Nanny
    else:
        if nanny_port:
            kwargs["service_ports"] = {"nanny": nanny_port}
        t = Worker

    if (not scheduler and not scheduler_file
            and dask.config.get("scheduler-address", None) is None):
        raise ValueError("Need to provide scheduler address like\n"
                         "dask-worker SCHEDULER_ADDRESS:8786")

    # Use an integer name when the given name converts cleanly; otherwise
    # keep it unchanged.
    with ignoring(TypeError, ValueError):
        name = int(name)

    # With several processes and a non-empty name, each one gets a
    # "-<index>" suffix.
    nannies = [
        t(scheduler,
          scheduler_file=scheduler_file,
          nthreads=nthreads,
          loop=loop,
          resources=resources,
          security=sec,
          contact_address=contact_address,
          host=host,
          port=port,
          dashboard_address=dashboard_address if dashboard else None,
          service_kwargs={"dashboard": {
              "prefix": dashboard_prefix
          }},
          name=name if nprocs == 1 or name is None or name == "" else
          str(name) + "-" + str(i),
          **kwargs) for i in range(nprocs)
    ]

    @gen.coroutine
    def close_all():
        # Unregister all workers from scheduler
        if nanny:
            yield [n.close(timeout=2) for n in nannies]

    def on_signal(signum):
        logger.info("Exiting on signal %d", signum)
        close_all()

    @gen.coroutine
    def run():
        # Start every created object, then wait for all of them to finish.
        yield nannies
        yield [n.finished() for n in nannies]

    install_signal_handlers(loop, cleanup=on_signal)

    try:
        loop.run_sync(run)
    except TimeoutError:
        # We already log the exception in nanny / worker. Don't do it again.
        raise TimeoutError("Timed out starting worker.") from None
    except KeyboardInterrupt:
        pass
    finally:
        logger.info("End worker")
def test_repr_local_keys():
    """File-based credentials show their file names in the repr."""
    security = Security(tls_ca_file="ca.pem", tls_scheduler_cert="scert.pem")
    text = repr(security)
    for filename in ("ca.pem", "scert.pem"):
        assert filename in text
MET = df['MET_pt'] # We can define a new key for cutflow (in this case 'all events'). Then we can put values into it. We need += because it's per-chunk (demonstrated below) output['cutflow']['all events'] += MET.size output['cutflow']['number of chunks'] += 1 # This fills our histogram once our data is collected. Always use .flatten() to make sure the array is reduced. The output key will be as defined in __init__ for self._accumulator; the hist key ('MET=') will be defined in the bin. output['MET'].fill(dataset=dataset, MET=MET.flatten()) return output def postprocess(self, accumulator): return accumulator sec_dask = Security(tls_ca_file='/etc/cmsaf-secrets/ca.pem', tls_worker_cert='/etc/cmsaf-secrets/usercert.pem', tls_worker_key='/etc/cmsaf-secrets/usercert.pem', tls_client_cert='/etc/cmsaf-secrets/usercert.pem', tls_client_key='/etc/cmsaf-secrets/usercert.pem', tls_scheduler_cert='/etc/cmsaf-secrets/hostcert.pem', tls_scheduler_key='/etc/cmsaf-secrets/hostcert.pem', require_encryption=True) HTCondorJob.submit_command = "condor_submit -spool" cluster = HTCondorCluster( cores=4, memory="2GB", disk="1GB", log_directory="logs", silence_logs="debug", scheduler_options={ "dashboard_address": "8786", "port": 8787,
async def test_require_encryption():
    """
    Functional test for "require_encryption" setting.

    Two ``Security`` objects are built from the same TLS configuration, one
    without and one with ``require-encryption``.  Connections are then
    attempted over several address schemes to check which combinations
    succeed and which raise ``RuntimeError``.
    """
    async def handle_comm(comm):
        # Nothing is exchanged in this test; drop the connection right away.
        comm.abort()

    config = {
        "distributed.comm.tls.ca-file": ca_file,
        "distributed.comm.tls.scheduler.key": key1,
        "distributed.comm.tls.scheduler.cert": cert1,
        "distributed.comm.tls.worker.cert": keycert1,
    }
    with dask.config.set(config):
        lenient = Security()
    config["distributed.comm.require-encryption"] = True
    with dask.config.set(config):
        strict = Security()

    # With these schemes a strict client connects to a listener set up with
    # either security object.
    for listen_addr in ["inproc://", "tls://"]:
        for listener_security in (lenient, strict):
            async with listen(
                listen_addr,
                handle_comm,
                **listener_security.get_listen_args("scheduler"),
            ) as listener:
                comm = await connect(
                    listener.contact_address,
                    **strict.get_connection_args("worker"),
                )
                comm.abort()

    # Helper that also checks the error message.  It is defined here but the
    # checks below use plain ``pytest.raises`` and never call it.
    @contextmanager
    def check_encryption_error():
        with pytest.raises(RuntimeError) as excinfo:
            yield
        assert "encryption required" in str(excinfo.value)

    for listen_addr in ["tcp://"]:
        async with listen(
            listen_addr,
            handle_comm,
            **lenient.get_listen_args("scheduler"),
        ) as listener:
            # Lenient client over plain TCP: accepted.
            comm = await connect(
                listener.contact_address,
                **lenient.get_connection_args("worker"),
            )
            comm.abort()

            # Strict client over plain TCP: refused.
            with pytest.raises(RuntimeError):
                await connect(
                    listener.contact_address,
                    **strict.get_connection_args("worker"),
                )

        # Strict listener on plain TCP: refused at creation.
        with pytest.raises(RuntimeError):
            listen(
                listen_addr,
                handle_comm,
                **strict.get_listen_args("scheduler"),
            )
def __init__(
        self,
        n_workers=1,
        # Cluster keywords
        loop=None,
        security=None,
        silence_logs="error",
        name=None,
        # Scheduler-only keywords
        scheduler_options=None,
        # Options for both scheduler and workers
        interface=None,
        protocol="tcp://",
        # Job keywords
        config_name=None,
        # Some SLURM only keywords
        scheduler_file: str = None,
        local_scheduler: bool = False,
        timeout: int = 500,
        **kwargs):
    """Collect scheduler and worker settings and start the loop runner.

    Parameters
    ----------
    n_workers
        Stored as ``self.worker_count`` (``self.n_workers`` starts at 0).
    loop, silence_logs
        Accepted but not referenced in this method body.
    security
        Stored as ``self.security``; a falsy value is replaced by
        ``Security()``.
    name
        Stored as ``self.name``.
    scheduler_options
        Options for the scheduler spec.  The defaults (protocol, dashboard
        address, security) are merged in only when ``local_scheduler`` is
        true.
    interface, config_name
        Added to ``kwargs``.
    protocol
        Stored as ``self.protocol``.
    scheduler_file
        Path of the scheduler file, ``"~/scheduler.json"`` by default; the
        ``~`` is expanded.
    local_scheduler
        Stored as ``self.local``.
    timeout
        Stored as ``self.timeout`` after conversion to ``int``.
    **kwargs
        Every entry whose value is not ``None`` overrides or extends the
        default worker options (cores, memory, python, dask_path, ...).

    Raises
    ------
    ValueError
        If the scheduler file already exists, or if no Python interpreter
        or ``dask-mpi`` path was supplied and none is found on the path.
    """
    self.status = "created"
    self.name = name

    if scheduler_file is None:
        scheduler_file = "~/scheduler.json"
    scheduler_file = os.path.expanduser(scheduler_file)
    if os.path.exists(scheduler_file):
        raise ValueError("A scheduler file was already found at " +
                         scheduler_file)

    # if interface is None:
    #     interface = dask.config.get("jobqueue.%s.interface" % config_name)
    if scheduler_options is None:
        scheduler_options = {}

    default_scheduler_options = {
        "protocol": protocol,
        "dashboard_address": ":8787",
        "security": security,
    }
    if local_scheduler:
        scheduler_options = dict(default_scheduler_options,
                                 **scheduler_options)
    self.scheduler_spec = {
        "cls": Scheduler,  # Use local scheduler for now
        "options": scheduler_options,
    }
    self.scheduler_file = scheduler_file

    kwargs["config_name"] = config_name
    kwargs["interface"] = interface
    self.protocol = protocol
    self.security = security or Security()
    self._kwargs = kwargs
    self.worker = None
    self.local = local_scheduler

    default_worker_options = {
        "cores": 1,
        "memory": "1GB",
        "python": shutil.which("python"),
        "dask_path": shutil.which("dask-mpi"),
        "threads_per_cpu": 4,
        "time": None,
        "constraint": None,
        "qos": None,
        "memory_limit": "0.5"
    }
    self.n_workers = 0
    # ``keys()`` is a live view, so it also reflects the entries merged in
    # from ``kwargs`` just below.
    self.worker_fields = default_worker_options.keys()
    self.worker_options = default_worker_options
    # ``kwargs`` is always a dict here; only its non-None values override
    # the defaults.
    self.worker_options.update({
        key: value
        for key, value in kwargs.items() if value is not None
    })
    print(kwargs)

    # Validate the executable paths before they are used and before the
    # loop runner is started.  ``shutil.which`` returns None when nothing
    # is found; previously that reached ``os.path.dirname(None)`` and
    # raised TypeError, so the ValueError for a missing dask-mpi was never
    # seen, and a failed check left the loop runner running.
    if self.worker_options["python"] is None:
        raise ValueError(
            "Python path was unspecified and was not found in the environment path"
        )
    if self.worker_options["dask_path"] is None:
        raise ValueError(
            "Dask MPI path was unspecified and was not found in the environment path"
        )

    self._loop_runner = LoopRunner(asynchronous=True)
    self._loop = self._loop_runner.loop
    self._lock = asyncio.Lock()
    self._loop_runner.start()
    self._supports_scaling = False

    self.worker_count = int(n_workers)
    self.nodes = int(self.worker_options["cores"])
    self.memory = self.worker_options["memory"]
    self.threads_per_cpu = int(self.worker_options["threads_per_cpu"])
    self.time = self.worker_options["time"]
    self.constraint = self.worker_options["constraint"]
    self.qos = self.worker_options["qos"]
    self.dask_path = self.worker_options["dask_path"]
    self.memory_limit = self.worker_options["memory_limit"]
    # A really dirty hack to ensure that mpirun is in the desired path
    self.mpirun_path = os.path.join(os.path.dirname(self.dask_path),
                                    "mpirun")
    self.timeout = int(timeout)
    print(self.dask_path)

    super().__init__(asynchronous=True)
def __init__(
        self,
        n_workers=0,
        job_cls: Job = None,
        # Cluster keywords
        loop=None,
        security=None,
        shared_temp_directory=None,
        silence_logs="error",
        name=None,
        asynchronous=False,
        # Scheduler-only keywords
        dashboard_address=None,
        host=None,
        scheduler_options=None,
        scheduler_cls=Scheduler,  # Use local scheduler for now
        # Options for both scheduler and workers
        interface=None,
        protocol=None,
        # Job keywords
        config_name=None,
        **job_kwargs):
    """Build the scheduler and worker specs and start the cluster.

    Parameters
    ----------
    n_workers : int
        When truthy, ``self.scale(n_workers)`` is called after the base
        class has been initialised.
    job_cls : Job, optional
        Job class to use; falls back to the ``job_cls`` class attribute.
    loop, silence_logs, name, asynchronous
        Forwarded unchanged to the base-class constructor.
    security : Security or bool, optional
        ``True`` is replaced by ``Security.temporary()``.  Any value
        other than ``None`` makes ``protocol`` default to ``"tls://"``.
    shared_temp_directory : str, optional
        Defaults to the ``jobqueue.<config_name>.shared-temp-directory``
        config value.
    dashboard_address, host
        Rejected: both must be passed through ``scheduler_options``.
    scheduler_options : dict, optional
        Overrides the default scheduler options; defaults to the
        ``jobqueue.<config_name>.scheduler-options`` config value.
    scheduler_cls
        Class placed in the scheduler spec.
    interface : str, optional
        Defaults to the ``jobqueue.<config_name>.interface`` config
        value; also used for the scheduler unless ``scheduler_options``
        sets ``host`` or ``interface``.
    protocol : str, optional
        Communication protocol passed to scheduler and jobs.
    config_name : str, optional
        Defaults to ``job_cls.default_config_name()``.
    **job_kwargs
        Extra options forwarded to the job class.

    Raises
    ------
    ValueError
        If no job class is available, or if ``dashboard_address`` or
        ``host`` is passed directly.
    ImportError
        If ``security=True`` and temporary credentials cannot be created
        because of a missing import.
    """
    self.status = Status.created

    default_job_cls = getattr(type(self), "job_cls", None)
    self.job_cls = default_job_cls
    if job_cls is not None:
        self.job_cls = job_cls

    if self.job_cls is None:
        raise ValueError(
            "You need to specify a Job type. Two cases:\n"
            "- you are inheriting from JobQueueCluster (most likely): you need to add a 'job_cls' class variable "
            "in your JobQueueCluster-derived class {}\n"
            "- you are using JobQueueCluster directly (less likely, only useful for tests): "
            "please explicitly pass a Job type through the 'job_cls' parameter."
            .format(type(self)))

    if dashboard_address is not None:
        # The example value previously read "12435", a transposition of
        # the port used on the line above it.
        raise ValueError(
            "Please pass 'dashboard_address' through 'scheduler_options': use\n"
            'cluster = {0}(..., scheduler_options={{"dashboard_address": ":12345"}}) rather than\n'
            'cluster = {0}(..., dashboard_address=":12345")'.format(
                self.__class__.__name__))

    if host is not None:
        raise ValueError(
            "Please pass 'host' through 'scheduler_options': use\n"
            'cluster = {0}(..., scheduler_options={{"host": "your-host"}}) rather than\n'
            'cluster = {0}(..., host="your-host")'.format(
                self.__class__.__name__))

    default_config_name = self.job_cls.default_config_name()
    if config_name is None:
        config_name = default_config_name

    if interface is None:
        interface = dask.config.get("jobqueue.%s.interface" % config_name)
    if scheduler_options is None:
        scheduler_options = dask.config.get(
            "jobqueue.%s.scheduler-options" % config_name, {})

    if protocol is None and security is not None:
        protocol = "tls://"

    if security is True:
        try:
            security = Security.temporary()
        except ImportError as err:
            # Chain the original error and keep a space between the two
            # sentences (they used to run together as "required,please").
            raise ImportError(
                "In order to use TLS without pregenerated certificates `cryptography` is required, "
                "please install it using either pip or conda") from err

    default_scheduler_options = {
        "protocol": protocol,
        "dashboard_address": ":8787",
        "security": security,
    }
    # scheduler_options overrides parameters common to both workers and scheduler
    scheduler_options = dict(default_scheduler_options, **scheduler_options)

    # Use the same network interface as the workers if scheduler ip has not
    # been set through scheduler_options via 'host' or 'interface'
    if "host" not in scheduler_options and "interface" not in scheduler_options:
        scheduler_options["interface"] = interface

    scheduler = {
        "cls": scheduler_cls,
        "options": scheduler_options,
    }

    if shared_temp_directory is None:
        shared_temp_directory = dask.config.get(
            "jobqueue.%s.shared-temp-directory" % config_name)
    self.shared_temp_directory = shared_temp_directory

    job_kwargs["config_name"] = config_name
    job_kwargs["interface"] = interface
    job_kwargs["protocol"] = protocol
    job_kwargs["security"] = self._get_worker_security(security)

    self._job_kwargs = job_kwargs

    worker = {"cls": self.job_cls, "options": self._job_kwargs}
    if "processes" in self._job_kwargs and self._job_kwargs["processes"] > 1:
        # One group suffix per process launched by a single job.
        worker["group"] = [
            "-" + str(i) for i in range(self._job_kwargs["processes"])
        ]

    self._dummy_job  # trigger property to ensure that the job is valid

    super().__init__(
        scheduler=scheduler,
        worker=worker,
        loop=loop,
        security=security,
        silence_logs=silence_logs,
        asynchronous=asynchronous,
        name=name,
    )

    if n_workers:
        self.scale(n_workers)
def __init__(
    self,
    scheduler_ip=None,
    scheduler_port=None,
    scheduler_file=None,
    worker_port=0,
    nthreads=None,
    loop=None,
    local_dir=None,
    local_directory=None,
    services=None,
    name=None,
    memory_limit="auto",
    reconnect=True,
    validate=False,
    quiet=False,
    resources=None,
    silence_logs=None,
    death_timeout=None,
    preload=None,
    preload_argv=None,
    preload_nanny=None,
    preload_nanny_argv=None,
    security=None,
    contact_address=None,
    listen_address=None,
    worker_class=None,
    env=None,
    interface=None,
    host=None,
    port=None,
    protocol=None,
    config=None,
    **worker_kwargs,
):
    """Configure the nanny: security, directories, preloads and handlers.

    Parameters
    ----------
    scheduler_ip, scheduler_port, scheduler_file
        Used to resolve ``self.scheduler_addr``.  Precedence: the
        ``"address"`` entry of ``scheduler_file``; then the
        ``scheduler-address`` config value when ``scheduler_ip`` is
        ``None``; then ``scheduler_ip`` alone; then the
        ``(scheduler_ip, scheduler_port)`` pair.
    worker_port, interface, protocol, host
        Stored into ``worker_kwargs`` under ``"port"``, ``"interface"``,
        ``"protocol"`` and ``"host"``.
    nthreads : int, optional
        Falls back to ``CPU_COUNT`` when falsy.
    loop : IOLoop, optional
        Falls back to ``IOLoop.current()``.
    local_dir
        Deprecated alias of ``local_directory``; emits a warning.
    local_directory : str, optional
        When ``None``, ``dask-worker-space`` under the
        ``temporary-directory`` config value (or the current working
        directory) is used.  The directory is created if missing.
    security : Security or dict, optional
        A dict is expanded into ``Security(**security)``; a falsy value
        becomes ``Security()``.
    preload, preload_argv, preload_nanny, preload_nanny_argv
        Each falls back to its ``distributed.worker.*`` /
        ``distributed.nanny.*`` config value when ``None``.
    worker_class
        Class stored on ``self.Worker``; defaults to ``Worker``.
    env : dict, optional
        Applied on top of the ``distributed.nanny.environ`` config
        value.  All values are converted to ``str``.
    config : dict, optional
        Falls back to ``dask.config.config``.
    port, listen_address
        Stored on ``self._start_port`` / ``self._listen_address``.
    **worker_kwargs
        Extra keyword arguments stored on ``self.worker_kwargs``.

    Raises
    ------
    TypeError
        If the ``distributed.nanny.environ`` config value is not a dict.
    """
    self._setup_logging(logger)
    self.loop = loop or IOLoop.current()

    if isinstance(security, dict):
        security = Security(**security)
    self.security = security or Security()
    assert isinstance(self.security, Security)
    self.connection_args = self.security.get_connection_args("worker")

    if local_dir is not None:
        warnings.warn("The local_dir keyword has moved to local_directory")
        local_directory = local_dir

    if local_directory is None:
        local_directory = dask.config.get(
            "temporary-directory") or os.getcwd()
        # Remember the parent directory before appending the worker-space
        # sub-directory.
        self._original_local_dir = local_directory
        local_directory = os.path.join(local_directory, "dask-worker-space")
    else:
        self._original_local_dir = local_directory

    self.local_directory = local_directory
    if not os.path.exists(self.local_directory):
        os.makedirs(self.local_directory, exist_ok=True)

    self.preload = preload
    if self.preload is None:
        self.preload = dask.config.get("distributed.worker.preload")
    self.preload_argv = preload_argv
    if self.preload_argv is None:
        self.preload_argv = dask.config.get(
            "distributed.worker.preload-argv")

    if preload_nanny is None:
        preload_nanny = dask.config.get("distributed.nanny.preload")
    if preload_nanny_argv is None:
        preload_nanny_argv = dask.config.get(
            "distributed.nanny.preload-argv")

    # Uses self.local_directory, so this must come after it is set above.
    self.preloads = preloading.process_preloads(
        self, preload_nanny, preload_nanny_argv, file_dir=self.local_directory)

    if scheduler_file:
        cfg = json_load_robust(scheduler_file)
        self.scheduler_addr = cfg["address"]
    elif scheduler_ip is None and dask.config.get("scheduler-address"):
        self.scheduler_addr = dask.config.get("scheduler-address")
    elif scheduler_port is None:
        self.scheduler_addr = coerce_to_address(scheduler_ip)
    else:
        self.scheduler_addr = coerce_to_address(
            (scheduler_ip, scheduler_port))

    if protocol is None:
        # Infer the protocol from the scheduler address when it has a
        # "<scheme>://" prefix.
        protocol_address = self.scheduler_addr.split("://")
        if len(protocol_address) == 2:
            protocol = protocol_address[0]

    self._given_worker_port = worker_port
    self.nthreads = nthreads or CPU_COUNT
    self.reconnect = reconnect
    self.validate = validate
    self.resources = resources
    self.death_timeout = parse_timedelta(death_timeout)

    self.Worker = Worker if worker_class is None else worker_class
    config_environ = dask.config.get("distributed.nanny.environ", {})
    if not isinstance(config_environ, dict):
        raise TypeError(
            "distributed.nanny.environ configuration must be of type dict. "
            f"Instead got {type(config_environ)}")
    self.env = config_environ.copy()
    # Values already present in the process environment take precedence
    # over the configured ones; the explicit ``env`` argument wins over both.
    for k in self.env:
        if k in os.environ:
            self.env[k] = os.environ[k]
    if env:
        self.env.update(env)
    self.env = {k: str(v) for k, v in self.env.items()}
    self.config = config or dask.config.config
    # ``host`` is captured here with the caller's value; the default
    # derived from the scheduler address further down only affects
    # ``self._start_host``.
    worker_kwargs.update({
        "port": worker_port,
        "interface": interface,
        "protocol": protocol,
        "host": host,
    })
    self.worker_kwargs = worker_kwargs

    self.contact_address = contact_address

    self.services = services
    self.name = name
    self.quiet = quiet
    self.auto_restart = True

    if silence_logs:
        silence_logging(level=silence_logs)
    self.silence_logs = silence_logs

    handlers = {
        "instantiate": self.instantiate,
        "kill": self.kill,
        "restart": self.restart,
        # cannot call it 'close' on the rpc side for naming conflict
        "get_logs": self.get_logs,
        "terminate": self.close,
        "close_gracefully": self.close_gracefully,
        "run": self.run,
        "plugin_add": self.plugin_add,
        "plugin_remove": self.plugin_remove,
    }

    self.plugins: dict[str, NannyPlugin] = {}

    super().__init__(handlers=handlers,
                     io_loop=self.loop,
                     connection_args=self.connection_args)

    # self.rpc is used only after the base-class initialiser has run.
    self.scheduler = self.rpc(self.scheduler_addr)
    self.memory_manager = NannyMemoryManager(self, memory_limit=memory_limit)

    if (not host and not interface
            and not self.scheduler_addr.startswith("inproc://")):
        # No explicit host/interface: derive one from the scheduler's
        # address host.
        host = get_ip(get_address_host(self.scheduler.address))

    self._start_port = port
    self._start_host = host
    self._interface = interface
    self._protocol = protocol

    self._listen_address = listen_address
    Nanny._instances.add(self)
    self.status = Status.init
def main(
    scheduler,
    host,
    nthreads,
    name,
    memory_limit,
    device_memory_limit,
    rmm_pool_size,
    pid_file,
    resources,
    dashboard,
    dashboard_address,
    local_directory,
    scheduler_file,
    interface,
    death_timeout,
    preload,
    dashboard_prefix,
    tls_ca_file,
    tls_cert,
    tls_key,
    enable_tcp_over_ucx,
    enable_infiniband,
    enable_nvlink,
    enable_rdmacm,
    net_devices,
    **kwargs,
):
    # Start one Nanny per GPU and run the IO loop until they finish.
    #
    # The GPU count comes from CUDA_VISIBLE_DEVICES when it is set and
    # from get_n_gpus() otherwise.  TLS is only enabled when the CA
    # file, certificate and key are all given.  Of **kwargs only
    # "preload_argv" is read.
    #
    # Raises ValueError when no scheduler address is available, when
    # both `interface` and `host` are given, or when an RMM pool is
    # requested but `rmm` cannot be imported.
    #
    # (Kept as comments rather than a docstring: this is presumably a
    # CLI entry point, and a docstring could surface as its help text.)
    enable_proctitle_on_current()
    enable_proctitle_on_children()

    if tls_ca_file and tls_cert and tls_key:
        sec = Security(
            tls_ca_file=tls_ca_file, tls_worker_cert=tls_cert, tls_worker_key=tls_key
        )
    else:
        sec = None

    try:
        nprocs = len(os.environ["CUDA_VISIBLE_DEVICES"].split(","))
    except KeyError:
        nprocs = get_n_gpus()

    if not nthreads:
        # The previous ``min(1, cpu_count // nprocs)`` could only ever be
        # 0 or 1, and was 0 whenever there were more GPUs than CPU cores.
        # Zero threads is never a valid request, so default to one
        # thread per GPU worker.
        nthreads = 1

    memory_limit = parse_memory_limit(memory_limit, nthreads, total_cores=nprocs)

    if pid_file:
        with open(pid_file, "w") as f:
            f.write(str(os.getpid()))

        def del_pid_file():
            if os.path.exists(pid_file):
                os.remove(pid_file)

        atexit.register(del_pid_file)

    services = {}

    if dashboard:
        try:
            from distributed.dashboard import BokehWorker
        except ImportError:
            # The dashboard is optional; run without it when it cannot
            # be imported.
            pass
        else:
            if dashboard_prefix:
                result = (BokehWorker, {"prefix": dashboard_prefix})
            else:
                result = BokehWorker
            services[("dashboard", dashboard_address)] = result

    if resources:
        # "A=1,B=2" or "A=1 B=2" -> {"A": 1.0, "B": 2.0}
        resources = resources.replace(",", " ").split()
        resources = dict(pair.split("=") for pair in resources)
        resources = valmap(float, resources)
    else:
        resources = None

    loop = IOLoop.current()

    # Read preload_argv before ``kwargs`` is replaced by the fixed set of
    # keyword arguments forwarded to every Nanny.
    preload_argv = kwargs.get("preload_argv", [])
    kwargs = {"worker_port": None, "listen_address": None}
    t = Nanny

    if not scheduler and not scheduler_file and "scheduler-address" not in config:
        raise ValueError(
            "Need to provide scheduler address like\n"
            "dask-worker SCHEDULER_ADDRESS:8786"
        )

    if interface:
        if host:
            raise ValueError("Can not specify both interface and host")
        else:
            host = get_ip_interface(interface)

    if rmm_pool_size is not None:
        try:
            import rmm  # noqa F401
        except ImportError as err:
            raise ValueError(
                "RMM pool requested but module 'rmm' is not available. "
                "For installation instructions, please see "
                "https://github.com/rapidsai/rmm"
            ) from err  # pragma: no cover
        rmm_pool_size = parse_bytes(rmm_pool_size)

    nannies = [
        t(
            scheduler,
            scheduler_file=scheduler_file,
            nthreads=nthreads,
            services=services,
            loop=loop,
            resources=resources,
            memory_limit=memory_limit,
            interface=get_ucx_net_devices(
                cuda_device_index=i,
                ucx_net_devices=net_devices,
                get_openfabrics=False,
                get_network=True,
            ),
            preload=(list(preload) or []) + ["dask_cuda.initialize"],
            preload_argv=(list(preload_argv) or []) + ["--create-cuda-context"],
            security=sec,
            env={"CUDA_VISIBLE_DEVICES": cuda_visible_devices(i)},
            plugins={CPUAffinity(get_cpu_affinity(i)), RMMPool(rmm_pool_size)},
            name=name if nprocs == 1 or not name else name + "-" + str(i),
            local_directory=local_directory,
            config={
                "ucx": get_ucx_config(
                    enable_tcp_over_ucx=enable_tcp_over_ucx,
                    enable_infiniband=enable_infiniband,
                    enable_nvlink=enable_nvlink,
                    enable_rdmacm=enable_rdmacm,
                    net_devices=net_devices,
                    cuda_device_index=i,
                )
            },
            data=(
                DeviceHostFile,
                {
                    "device_memory_limit": get_device_total_memory(index=i)
                    if (device_memory_limit == "auto" or device_memory_limit == int(0))
                    else parse_bytes(device_memory_limit),
                    "memory_limit": memory_limit,
                    "local_directory": local_directory,
                },
            ),
            **kwargs,
        )
        for i in range(nprocs)
    ]

    @gen.coroutine
    def close_all():
        # Unregister all workers from scheduler
        yield [n._close(timeout=2) for n in nannies]

    def on_signal(signum):
        logger.info("Exiting on signal %d", signum)
        # Hand the future back instead of dropping it, so the caller can
        # wait for the nannies to close.
        # NOTE(review): assumes install_signal_handlers waits on the
        # value returned by ``cleanup`` before stopping the loop - confirm.
        return close_all()

    @gen.coroutine
    def run():
        yield nannies
        yield [n.finished() for n in nannies]

    install_signal_handlers(loop, cleanup=on_signal)

    try:
        loop.run_sync(run)
    except (KeyboardInterrupt, TimeoutError):
        pass
    finally:
        logger.info("End worker")
def test_require_encryption():
    """
    Functional test for "require_encryption" setting.
    """
    @gen.coroutine
    def handle_comm(comm):
        comm.abort()

    config = {
        'tls': {
            'ca-file': ca_file,
            'scheduler': {'key': key1, 'cert': cert1},
            'worker': {'cert': keycert1},
        },
    }
    # ``sec`` is built without the flag, ``sec2`` with it switched on.
    with new_config(config):
        sec = Security()
    config['require-encryption'] = True
    with new_config(config):
        sec2 = Security()

    # Encrypted or in-process transports: a client that requires
    # encryption connects to a listener created with either Security object.
    for listen_addr in ('inproc://', 'tls://'):
        for server_sec in (sec, sec2):
            listen_args = server_sec.get_listen_args('scheduler')
            with listen(listen_addr, handle_comm,
                        connection_args=listen_args) as listener:
                comm = yield connect(
                    listener.contact_address,
                    connection_args=sec2.get_connection_args('worker'))
                comm.abort()

    @contextmanager
    def check_encryption_error():
        with pytest.raises(RuntimeError) as excinfo:
            yield
        assert "encryption required" in str(excinfo.value)

    # Plain TCP: only works when neither side requires encryption.
    for listen_addr in ('tcp://',):
        lenient_listen_args = sec.get_listen_args('scheduler')
        with listen(listen_addr, handle_comm,
                    connection_args=lenient_listen_args) as listener:
            comm = yield connect(
                listener.contact_address,
                connection_args=sec.get_connection_args('worker'))
            comm.abort()

            with pytest.raises(RuntimeError):
                yield connect(
                    listener.contact_address,
                    connection_args=sec2.get_connection_args('worker'))

        with pytest.raises(RuntimeError):
            listen(listen_addr, handle_comm,
                   connection_args=sec2.get_listen_args('scheduler'))
def main(scheduler, host, worker_port, http_port, nanny_port, nthreads,
         nprocs, nanny, name, memory_limit, pid_file, reconnect, resources,
         bokeh, bokeh_port, local_directory, scheduler_file, interface,
         death_timeout, preload, bokeh_prefix, tls_ca_file, tls_cert,
         tls_key):
    # Launch `nprocs` workers (Nanny processes when `nanny` is true,
    # plain Workers otherwise) and run the IO loop until one of them
    # reports a 'closed' status.
    #
    # Invalid option combinations are logged and end the process with
    # exit(1).  Raises ValueError when no scheduler address is
    # available or when both `interface` and `host` are given.
    #
    # (Kept as comments rather than a docstring: this is presumably a
    # CLI entry point, and a docstring could surface as its help text.)
    sec = Security(
        tls_ca_file=tls_ca_file,
        tls_worker_cert=tls_cert,
        tls_worker_key=tls_key,
    )

    if nanny:
        port = nanny_port
    else:
        port = worker_port

    if nprocs > 1 and worker_port != 0:
        logger.error(
            "Failed to launch worker. You cannot use the --port argument when nprocs > 1."
        )
        exit(1)

    if nprocs > 1 and name:
        logger.error(
            "Failed to launch worker. You cannot use the --name argument when nprocs > 1."
        )
        exit(1)

    if nprocs > 1 and not nanny:
        logger.error(
            "Failed to launch worker. You cannot use the --no-nanny argument when nprocs > 1."
        )
        exit(1)

    if not nthreads:
        # Integer division gives 0 when there are more processes than
        # cores; never ask a worker for zero threads.
        nthreads = max(1, _ncores // nprocs)

    if pid_file:
        with open(pid_file, 'w') as f:
            f.write(str(os.getpid()))

        def del_pid_file():
            if os.path.exists(pid_file):
                os.remove(pid_file)

        atexit.register(del_pid_file)

    services = {('http', http_port): HTTPWorker}

    if bokeh:
        try:
            from distributed.bokeh.worker import BokehWorker
        except ImportError:
            # Bokeh is optional; run without the diagnostics service.
            pass
        else:
            if bokeh_prefix:
                result = (BokehWorker, {'prefix': bokeh_prefix})
            else:
                result = BokehWorker
            services[('bokeh', bokeh_port)] = result

    if resources:
        # "A=1,B=2" or "A=1 B=2" -> {"A": 1.0, "B": 2.0}
        resources = resources.replace(',', ' ').split()
        resources = dict(pair.split('=') for pair in resources)
        resources = valmap(float, resources)
    else:
        resources = None

    loop = IOLoop.current()

    if nanny:
        kwargs = {'worker_port': worker_port}
        t = Nanny
    else:
        kwargs = {}
        if nanny_port:
            kwargs['service_ports'] = {'nanny': nanny_port}
        t = Worker

    if scheduler_file:
        # Wait for the scheduler to create the file, then retry a few
        # times in case it is read while still being written.
        while not os.path.exists(scheduler_file):
            sleep(0.01)
        for _ in range(10):
            try:
                with open(scheduler_file) as f:
                    cfg = json.load(f)
                scheduler = cfg['address']
                break
            except (ValueError, KeyError):  # race with scheduler on file
                sleep(0.01)

    if not scheduler:
        raise ValueError("Need to provide scheduler address like\n"
                         "dask-worker SCHEDULER_ADDRESS:8786")

    if interface:
        if host:
            raise ValueError("Can not specify both interface and host")
        else:
            host = get_ip_interface(interface)

    if host or port:
        addr = uri_from_host_port(host, port, 0)
    else:
        # Choose appropriate address for scheduler
        addr = None

    nannies = [
        t(scheduler,
          ncores=nthreads,
          services=services,
          name=name,
          loop=loop,
          resources=resources,
          memory_limit=memory_limit,
          reconnect=reconnect,
          local_dir=local_directory,
          death_timeout=death_timeout,
          preload=preload,
          security=sec,
          **kwargs) for _ in range(nprocs)
    ]

    @gen.coroutine
    def close_all():
        try:
            if nanny:
                yield [n._close(timeout=2) for n in nannies]
        finally:
            loop.stop()

    def handle_signal(signum, frame):
        logger.info("Exiting on signal %d", signum)
        if loop._running:
            loop.add_callback_from_signal(loop.stop)
        else:
            exit(0)

    # NOTE: We can't use the generic install_signal_handlers() function from
    # distributed.cli.utils because we're handling the signal differently.
    signal.signal(signal.SIGINT, handle_signal)
    signal.signal(signal.SIGTERM, handle_signal)

    for n in nannies:
        n.start(addr)

    @gen.coroutine
    def run():
        while all(n.status != 'closed' for n in nannies):
            yield gen.sleep(0.2)

    try:
        loop.run_sync(run)
    except (KeyboardInterrupt, TimeoutError):
        pass
    finally:
        logger.info("End worker")

    # Clean exit: unregister all workers from scheduler
    loop.run_sync(close_all)
import pytest

import dask
from distributed import Client, Scheduler, Worker
from distributed.comm import connect, listen, ws
from distributed.comm.core import FatalCommClosedError
from distributed.comm.registry import backends, get_backend
from distributed.security import Security
from distributed.utils_test import (  # noqa: F401
    cleanup, gen_cluster, get_client_ssl_context, get_server_ssl_context, inc,
)

from .test_comms import check_tls_extra

# Temporary Security object, created once when the module is imported.
security = Security.temporary()


def test_registered():
    """The "ws" backend is registered and resolves to a ``ws.WSBackend``."""
    assert "ws" in backends
    backend = get_backend("ws")
    assert isinstance(backend, ws.WSBackend)


@pytest.mark.asyncio
async def test_listen_connect(cleanup):
    async def handle_comm(comm):
        # Echo every message read from the comm straight back to it.
        while True:
            msg = await comm.read()
            await comm.write(msg)
def main(
        scheduler,
        host,
        worker_port,
        listen_address,
        contact_address,
        nanny_port,
        nthreads,
        nprocs,
        nanny,
        name,
        memory_limit,
        pid_file,
        reconnect,
        resources,
        dashboard,
        bokeh_port,
        local_directory,
        scheduler_file,
        interface,
        protocol,
        death_timeout,
        preload,
        preload_argv,
        dashboard_prefix,
        tls_ca_file,
        tls_cert,
        tls_key,
        dashboard_address,
):
    # Launch `nprocs` workers (Nanny processes when `nanny` is true,
    # plain Workers otherwise) and run the IO loop until one of them
    # reports a "closed" status.
    #
    # Invalid option combinations and unparsable listen/contact
    # addresses are logged and end the process with exit(1).  Raises
    # ValueError when no scheduler address is available from the
    # arguments or the "scheduler-address" config value.
    #
    # (Kept as comments rather than a docstring: this is presumably a
    # CLI entry point, and a docstring could surface as its help text.)
    g0, g1, g2 = gc.get_threshold()  # https://github.com/dask/distributed/issues/1653
    gc.set_threshold(g0 * 3, g1 * 3, g2 * 3)

    enable_proctitle_on_current()
    enable_proctitle_on_children()

    if bokeh_port is not None:
        warnings.warn(
            "The --bokeh-port flag has been renamed to --dashboard-address. "
            "Consider adding ``--dashboard-address :%d`` " % bokeh_port)
        dashboard_address = bokeh_port

    sec = Security(tls_ca_file=tls_ca_file,
                   tls_worker_cert=tls_cert,
                   tls_worker_key=tls_key)

    if nprocs > 1 and worker_port != 0:
        logger.error(
            "Failed to launch worker. You cannot use the --port argument when nprocs > 1."
        )
        exit(1)

    if nprocs > 1 and not nanny:
        logger.error(
            "Failed to launch worker. You cannot use the --no-nanny argument when nprocs > 1."
        )
        exit(1)

    if contact_address and not listen_address:
        logger.error(
            "Failed to launch worker. "
            "Must specify --listen-address when --contact-address is given")
        exit(1)

    if nprocs > 1 and listen_address:
        logger.error("Failed to launch worker. "
                     "You cannot specify --listen-address when nprocs > 1.")
        exit(1)

    if (worker_port or host) and listen_address:
        logger.error(
            "Failed to launch worker. "
            "You cannot specify --listen-address when --worker-port or --host is given."
        )
        exit(1)

    try:
        if listen_address:
            (host, worker_port) = get_address_host_port(listen_address,
                                                        strict=True)

        if contact_address:
            # we only need this to verify it is getting parsed
            (_, _) = get_address_host_port(contact_address, strict=True)
        else:
            # if contact address is not present we use the listen_address for contact
            contact_address = listen_address
    except ValueError as e:
        logger.error("Failed to launch worker. " + str(e))
        exit(1)

    if nanny:
        port = nanny_port
    else:
        port = worker_port

    if not nthreads:
        # Integer division gives 0 when there are more processes than
        # cores; never ask a worker for zero threads.
        nthreads = max(1, _ncores // nprocs)

    if pid_file:
        with open(pid_file, "w") as f:
            f.write(str(os.getpid()))

        def del_pid_file():
            if os.path.exists(pid_file):
                os.remove(pid_file)

        atexit.register(del_pid_file)

    services = {}

    if resources:
        # "A=1,B=2" or "A=1 B=2" -> {"A": 1.0, "B": 2.0}
        resources = resources.replace(",", " ").split()
        resources = dict(pair.split("=") for pair in resources)
        resources = valmap(float, resources)
    else:
        resources = None

    loop = IOLoop.current()

    if nanny:
        kwargs = {"worker_port": worker_port, "listen_address": listen_address}
        t = Nanny
    else:
        kwargs = {}
        if nanny_port:
            kwargs["service_ports"] = {"nanny": nanny_port}
        t = Worker

    if (not scheduler and not scheduler_file
            and dask.config.get("scheduler-address", None) is None):
        raise ValueError("Need to provide scheduler address like\n"
                         "dask-worker SCHEDULER_ADDRESS:8786")

    if death_timeout is not None:
        death_timeout = parse_timedelta(death_timeout, "s")

    nannies = [
        t(scheduler,
          scheduler_file=scheduler_file,
          ncores=nthreads,
          services=services,
          loop=loop,
          resources=resources,
          memory_limit=memory_limit,
          reconnect=reconnect,
          local_dir=local_directory,
          death_timeout=death_timeout,
          preload=preload,
          preload_argv=preload_argv,
          security=sec,
          contact_address=contact_address,
          interface=interface,
          protocol=protocol,
          host=host,
          port=port,
          dashboard_address=dashboard_address if dashboard else None,
          # The key used to be misspelled "bokhe", so the prefix never
          # reached the dashboard service.
          # NOTE(review): "dashboard" is presumed to be the key the
          # worker looks up in service_kwargs - confirm against Worker.
          service_kwargs={"dashboard": {
              "prefix": dashboard_prefix
          }},
          name=name if nprocs == 1 or not name else name + "-" + str(i),
          **kwargs) for i in range(nprocs)
    ]

    @gen.coroutine
    def close_all():
        # Unregister all workers from scheduler
        if nanny:
            yield [n.close(timeout=2) for n in nannies]

    def on_signal(signum):
        logger.info("Exiting on signal %d", signum)
        # Hand the future back instead of dropping it, so the caller can
        # wait for the nannies to close.
        # NOTE(review): assumes install_signal_handlers waits on the
        # value returned by ``cleanup`` before stopping the loop - confirm.
        return close_all()

    @gen.coroutine
    def run():
        yield nannies
        while all(n.status != "closed" for n in nannies):
            yield gen.sleep(0.2)

    install_signal_handlers(loop, cleanup=on_signal)

    try:
        loop.run_sync(run)
    except (KeyboardInterrupt, TimeoutError):
        pass
    finally:
        logger.info("End worker")
def test_repr():
    """repr() shows require_encryption plus only the TLS fields that were set."""
    expected = (
        "Security(require_encryption=False, tls_ca_file='ca.pem', "
        "tls_scheduler_cert='scert.pem')"
    )
    security_obj = Security(tls_ca_file="ca.pem", tls_scheduler_cert="scert.pem")
    assert repr(security_obj) == expected
def main(scheduler, host, worker_port, listen_address, contact_address,
         nanny_port, nthreads, nprocs, nanny, name, memory_limit, pid_file,
         reconnect, resources, bokeh, bokeh_port, local_directory,
         scheduler_file, interface, death_timeout, preload, bokeh_prefix,
         tls_ca_file, tls_cert, tls_key):
    # Launch `nprocs` workers (Nanny processes when `nanny` is true,
    # plain Workers otherwise), start them on the chosen address and
    # run the IO loop until one of them reports a 'closed' status.
    #
    # Invalid option combinations and unparsable listen/contact
    # addresses are logged and end the process with exit(1).  Raises
    # ValueError when neither a scheduler address nor a scheduler file
    # is given, or when both `interface` and `host` are given.
    #
    # (Kept as comments rather than a docstring: this is presumably a
    # CLI entry point, and a docstring could surface as its help text.)
    sec = Security(
        tls_ca_file=tls_ca_file,
        tls_worker_cert=tls_cert,
        tls_worker_key=tls_key,
    )

    if nprocs > 1 and worker_port != 0:
        logger.error(
            "Failed to launch worker. You cannot use the --port argument when nprocs > 1."
        )
        exit(1)

    if nprocs > 1 and name:
        logger.error(
            "Failed to launch worker. You cannot use the --name argument when nprocs > 1."
        )
        exit(1)

    if nprocs > 1 and not nanny:
        logger.error(
            "Failed to launch worker. You cannot use the --no-nanny argument when nprocs > 1."
        )
        exit(1)

    if contact_address and not listen_address:
        logger.error(
            "Failed to launch worker. "
            "Must specify --listen-address when --contact-address is given")
        exit(1)

    if nprocs > 1 and listen_address:
        logger.error("Failed to launch worker. "
                     "You cannot specify --listen-address when nprocs > 1.")
        exit(1)

    if (worker_port or host) and listen_address:
        logger.error(
            "Failed to launch worker. "
            "You cannot specify --listen-address when --worker-port or --host is given."
        )
        exit(1)

    try:
        if listen_address:
            (host, worker_port) = get_address_host_port(listen_address,
                                                        strict=True)

        if contact_address:
            # we only need this to verify it is getting parsed
            (_, _) = get_address_host_port(contact_address, strict=True)
        else:
            # if contact address is not present we use the listen_address for contact
            contact_address = listen_address
    except ValueError as e:
        logger.error("Failed to launch worker. " + str(e))
        exit(1)

    if nanny:
        port = nanny_port
    else:
        port = worker_port

    if not nthreads:
        # Integer division gives 0 when there are more processes than
        # cores; never ask a worker for zero threads.
        nthreads = max(1, _ncores // nprocs)

    if pid_file:
        with open(pid_file, 'w') as f:
            f.write(str(os.getpid()))

        def del_pid_file():
            if os.path.exists(pid_file):
                os.remove(pid_file)

        atexit.register(del_pid_file)

    services = {}

    if bokeh:
        try:
            from distributed.bokeh.worker import BokehWorker
        except ImportError:
            # Bokeh is optional; run without the diagnostics service.
            pass
        else:
            if bokeh_prefix:
                result = (BokehWorker, {'prefix': bokeh_prefix})
            else:
                result = BokehWorker
            services[('bokeh', bokeh_port)] = result

    if resources:
        # "A=1,B=2" or "A=1 B=2" -> {"A": 1.0, "B": 2.0}
        resources = resources.replace(',', ' ').split()
        resources = dict(pair.split('=') for pair in resources)
        resources = valmap(float, resources)
    else:
        resources = None

    loop = IOLoop.current()

    if nanny:
        kwargs = {'worker_port': worker_port, 'listen_address': listen_address}
        t = Nanny
    else:
        kwargs = {}
        if nanny_port:
            kwargs['service_ports'] = {'nanny': nanny_port}
        t = Worker

    if not scheduler and not scheduler_file:
        raise ValueError("Need to provide scheduler address like\n"
                         "dask-worker SCHEDULER_ADDRESS:8786")

    if interface:
        if host:
            raise ValueError("Can not specify both interface and host")
        else:
            host = get_ip_interface(interface)

    if host or port:
        addr = uri_from_host_port(host, port, 0)
    else:
        # Choose appropriate address for scheduler
        addr = None

    nannies = [
        t(scheduler,
          scheduler_file=scheduler_file,
          ncores=nthreads,
          services=services,
          name=name,
          loop=loop,
          resources=resources,
          memory_limit=memory_limit,
          reconnect=reconnect,
          local_dir=local_directory,
          death_timeout=death_timeout,
          preload=preload,
          security=sec,
          contact_address=contact_address,
          **kwargs) for _ in range(nprocs)
    ]

    @gen.coroutine
    def close_all():
        # Unregister all workers from scheduler
        if nanny:
            yield [n._close(timeout=2) for n in nannies]

    def on_signal(signum):
        logger.info("Exiting on signal %d", signum)
        # Hand the future back instead of dropping it, so the caller can
        # wait for the nannies to close.
        # NOTE(review): assumes install_signal_handlers waits on the
        # value returned by ``cleanup`` before stopping the loop - confirm.
        return close_all()

    @gen.coroutine
    def run():
        yield [n.start(addr) for n in nannies]
        while all(n.status != 'closed' for n in nannies):
            yield gen.sleep(0.2)

    install_signal_handlers(loop, cleanup=on_signal)

    try:
        loop.run_sync(run)
    except (KeyboardInterrupt, TimeoutError):
        pass
    finally:
        logger.info("End worker")
def main(
    scheduler,
    host,
    nthreads,
    name,
    memory_limit,
    device_memory_limit,
    rmm_pool_size,
    rmm_managed_memory,
    pid_file,
    resources,
    dashboard,
    dashboard_address,
    local_directory,
    scheduler_file,
    interface,
    death_timeout,
    preload,
    dashboard_prefix,
    tls_ca_file,
    tls_cert,
    tls_key,
    enable_tcp_over_ucx,
    enable_infiniband,
    enable_nvlink,
    enable_rdmacm,
    net_devices,
    enable_jit_unspill,
    **kwargs,
):
    # Build a CUDAWorker from the given options and run it on the
    # current IO loop until it finishes or the process is interrupted.
    # Raises ValueError when the scheduler address starts with "-".

    # TLS is only configured when all three pieces are supplied.
    tls_complete = all((tls_ca_file, tls_cert, tls_key))
    security = (
        Security(
            tls_ca_file=tls_ca_file,
            tls_worker_cert=tls_cert,
            tls_worker_key=tls_key,
        )
        if tls_complete
        else None
    )

    # A leading "-" means an unsupported option was taken for the address.
    looks_like_flag = isinstance(scheduler, str) and scheduler.startswith("-")
    if looks_like_flag:
        message = (
            "The scheduler address can't start with '-'. Please check "
            "your command line arguments, you probably attempted to use "
            "unsupported one. Scheduler address: %s" % scheduler
        )
        raise ValueError(message)

    # Arguments are positional, in the same order as this signature with
    # the three TLS options replaced by the single ``security`` object.
    worker = CUDAWorker(
        scheduler,
        host,
        nthreads,
        name,
        memory_limit,
        device_memory_limit,
        rmm_pool_size,
        rmm_managed_memory,
        pid_file,
        resources,
        dashboard,
        dashboard_address,
        local_directory,
        scheduler_file,
        interface,
        death_timeout,
        preload,
        dashboard_prefix,
        security,
        enable_tcp_over_ucx,
        enable_infiniband,
        enable_nvlink,
        enable_rdmacm,
        net_devices,
        enable_jit_unspill,
        **kwargs,
    )

    async def _shutdown(signum):
        logger.info("Exiting on signal %d", signum)
        await worker.close()

    async def _serve():
        await worker
        await worker.finished()

    io_loop = IOLoop.current()
    install_signal_handlers(io_loop, cleanup=_shutdown)

    try:
        io_loop.run_sync(_serve)
    except (KeyboardInterrupt, TimeoutError):
        pass
    finally:
        logger.info("End worker")