Example #1
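# A default Security() exposes tls_* attributes such as tls_ca_file; unknown tls_* names
# raise AttributeError on both read and write.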
def test_attribute_error():
    sec = Security()
    assert hasattr(sec, 'tls_ca_file')
    with pytest.raises(AttributeError):
        sec.tls_foobar
    with pytest.raises(AttributeError):
        sec.tls_foobar = ""
Example #2
def test_require_encryption():
    """
    Functional test for "require_encryption" setting.
    """
    @gen.coroutine
    def handle_comm(comm):
        comm.abort()

    c = {
        'tls': {
            'ca-file': ca_file,
            'scheduler': {
                'key': key1,
                'cert': cert1,
            },
            'worker': {
                'cert': keycert1,
            },
        },
    }
    with new_config(c):
        sec = Security()
    c['require-encryption'] = True
    with new_config(c):
        sec2 = Security()

    for listen_addr in ['inproc://', 'tls://']:
        with listen(listen_addr, handle_comm,
                    connection_args=sec.get_listen_args('scheduler')) as listener:
            comm = yield connect(listener.contact_address,
                                 connection_args=sec2.get_connection_args('worker'))
            comm.abort()

        with listen(listen_addr, handle_comm,
                    connection_args=sec2.get_listen_args('scheduler')) as listener:
            comm = yield connect(listener.contact_address,
                                 connection_args=sec2.get_connection_args('worker'))
            comm.abort()

    @contextmanager
    def check_encryption_error():
        with pytest.raises(RuntimeError) as excinfo:
            yield
        assert "encryption required" in str(excinfo.value)

    for listen_addr in ['tcp://']:
        with listen(listen_addr, handle_comm,
                    connection_args=sec.get_listen_args('scheduler')) as listener:
            comm = yield connect(listener.contact_address,
                                 connection_args=sec.get_connection_args('worker'))
            comm.abort()

            with pytest.raises(RuntimeError):
                yield connect(listener.contact_address,
                              connection_args=sec2.get_connection_args('worker'))

        with pytest.raises(RuntimeError):
            listen(listen_addr, handle_comm,
                   connection_args=sec2.get_listen_args('scheduler'))
Example #3
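# get_listen_args() returns a dict with a 'require_encryption' flag and an 'ssl_context'
# built from the per-role TLS settings; a role without a certificate (here 'client')
# gets no ssl_context.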
def test_listen_args():
    def basic_checks(ctx):
        assert ctx.verify_mode == ssl.CERT_REQUIRED
        assert ctx.check_hostname is False

    def many_ciphers(ctx):
        if sys.version_info >= (3, 6):
            assert len(ctx.get_ciphers()) > 2  # Most likely

    c = {
        'tls': {
            'ca-file': ca_file,
            'scheduler': {
                'key': key1,
                'cert': cert1,
            },
            'worker': {
                'cert': keycert1,
            },
        },
    }
    with new_config(c):
        sec = Security()

    d = sec.get_listen_args('scheduler')
    assert not d['require_encryption']
    ctx = d['ssl_context']
    basic_checks(ctx)
    many_ciphers(ctx)

    d = sec.get_listen_args('worker')
    ctx = d['ssl_context']
    basic_checks(ctx)
    many_ciphers(ctx)

    # No cert defined => no TLS
    d = sec.get_listen_args('client')
    assert d.get('ssl_context') is None

    # With more settings
    c['tls']['ciphers'] = FORCED_CIPHER
    c['require-encryption'] = True

    with new_config(c):
        sec = Security()

    d = sec.get_listen_args('scheduler')
    assert d['require_encryption']
    ctx = d['ssl_context']
    basic_checks(ctx)
    if sys.version_info >= (3, 6):
        supported_ciphers = ctx.get_ciphers()
        tls_12_ciphers = [c for c in supported_ciphers if c['protocol'] == 'TLSv1.2']
        assert len(tls_12_ciphers) == 1
        tls_13_ciphers = [c for c in supported_ciphers if c['protocol'] == 'TLSv1.3']
        if len(tls_13_ciphers):
            assert len(tls_13_ciphers) == 3
Example #4
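# End-to-end TLS test over listen()/connect(): peer addresses use the tls:// scheme,
# connecting with the cert-less 'client' role raises TypeError, and the negotiated
# cipher is exposed via comm.extra_info['cipher'].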
def test_tls_listen_connect():
    """
    Functional test for TLS connection args.
    """
    @gen.coroutine
    def handle_comm(comm):
        peer_addr = comm.peer_address
        assert peer_addr.startswith('tls://')
        yield comm.write('hello')
        yield comm.close()

    c = {
        'tls': {
            'ca-file': ca_file,
            'scheduler': {
                'key': key1,
                'cert': cert1,
            },
            'worker': {
                'cert': keycert1,
            },
        },
    }
    with new_config(c):
        sec = Security()

    c['tls']['ciphers'] = FORCED_CIPHER
    with new_config(c):
        forced_cipher_sec = Security()

    with listen('tls://', handle_comm,
                connection_args=sec.get_listen_args('scheduler')) as listener:
        comm = yield connect(listener.contact_address,
                             connection_args=sec.get_connection_args('worker'))
        msg = yield comm.read()
        assert msg == 'hello'
        comm.abort()

        # No SSL context for client
        with pytest.raises(TypeError):
            yield connect(listener.contact_address,
                          connection_args=sec.get_connection_args('client'))

        # Check forced cipher
        comm = yield connect(listener.contact_address,
                             connection_args=forced_cipher_sec.get_connection_args('worker'))
        cipher, _, _ = comm.extra_info['cipher']
        assert cipher in [FORCED_CIPHER] + TLS_13_CIPHERS
        comm.abort()
Example #5
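# get_tls_config_for_role() flattens the role-specific TLS settings into a dict of
# ca_file, key, cert, and ciphers; unknown roles raise ValueError.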
def test_tls_config_for_role():
    c = {
        'tls': {
            'ca-file': 'ca.pem',
            'scheduler': {
                'key': 'skey.pem',
                'cert': 'scert.pem',
            },
            'worker': {
                'cert': 'wcert.pem',
            },
            'ciphers': FORCED_CIPHER,
        },
    }
    with new_config(c):
        sec = Security()
    t = sec.get_tls_config_for_role('scheduler')
    assert t == {
        'ca_file': 'ca.pem',
        'key': 'skey.pem',
        'cert': 'scert.pem',
        'ciphers': FORCED_CIPHER,
    }
    t = sec.get_tls_config_for_role('worker')
    assert t == {
        'ca_file': 'ca.pem',
        'key': None,
        'cert': 'wcert.pem',
        'ciphers': FORCED_CIPHER,
    }
    t = sec.get_tls_config_for_role('client')
    assert t == {
        'ca_file': 'ca.pem',
        'key': None,
        'cert': None,
        'ciphers': FORCED_CIPHER,
    }
    with pytest.raises(ValueError):
        sec.get_tls_config_for_role('supervisor')
Example #6
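# Keyword arguments passed to Security() override values from the configuration, except
# that a None value is ignored and the configured default is kept.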
def test_kwargs():
    c = {
        'tls': {
            'ca-file': 'ca.pem',
            'scheduler': {
                'key': 'skey.pem',
                'cert': 'scert.pem',
            },
        },
    }
    with new_config(c):
        sec = Security(tls_scheduler_cert='newcert.pem',
                       require_encryption=True,
                       tls_ca_file=None)
    assert sec.require_encryption is True
    # None value didn't override default
    assert sec.tls_ca_file == 'ca.pem'
    assert sec.tls_ciphers is None
    assert sec.tls_client_key is None
    assert sec.tls_client_cert is None
    assert sec.tls_scheduler_key == 'skey.pem'
    assert sec.tls_scheduler_cert == 'newcert.pem'
    assert sec.tls_worker_key is None
    assert sec.tls_worker_cert is None
Example #7
def test_kwargs():
    c = {
        "tls": {
            "ca-file": "ca.pem",
            "scheduler": {
                "key": "skey.pem",
                "cert": "scert.pem"
            },
        }
    }
    with new_config(c):
        sec = Security(tls_scheduler_cert="newcert.pem",
                       require_encryption=True,
                       tls_ca_file=None)
    assert sec.require_encryption is True
    # None value didn't override default
    assert sec.tls_ca_file == "ca.pem"
    assert sec.tls_ciphers is None
    assert sec.tls_client_key is None
    assert sec.tls_client_cert is None
    assert sec.tls_scheduler_key == "skey.pem"
    assert sec.tls_scheduler_cert == "newcert.pem"
    assert sec.tls_worker_key is None
    assert sec.tls_worker_cert is None
Example #8
def test_tls_config_for_role():
    c = {
        'tls': {
            'ca-file': 'ca.pem',
            'scheduler': {
                'key': 'skey.pem',
                'cert': 'scert.pem',
            },
            'worker': {
                'cert': 'wcert.pem',
            },
            'ciphers': FORCED_CIPHER,
        },
    }
    with new_config(c):
        sec = Security()
    t = sec.get_tls_config_for_role('scheduler')
    assert t == {
        'ca_file': 'ca.pem',
        'key': 'skey.pem',
        'cert': 'scert.pem',
        'ciphers': FORCED_CIPHER,
    }
    t = sec.get_tls_config_for_role('worker')
    assert t == {
        'ca_file': 'ca.pem',
        'key': None,
        'cert': 'wcert.pem',
        'ciphers': FORCED_CIPHER,
    }
    t = sec.get_tls_config_for_role('client')
    assert t == {
        'ca_file': 'ca.pem',
        'key': None,
        'cert': None,
        'ciphers': FORCED_CIPHER,
    }
    with pytest.raises(ValueError):
        sec.get_tls_config_for_role('supervisor')
Example #9
    def __init__(
        self,
        n_workers: int = 0,
        worker_class: str = "dask.distributed.Nanny",
        worker_options: dict = {},
        scheduler_options: dict = {},
        docker_image="daskdev/dask:latest",
        docker_args: str = "",
        env_vars: dict = {},
        security: bool = True,
        protocol: str = None,
        **kwargs,
    ):
        if self.scheduler_class is None or self.worker_class is None:
            raise RuntimeError(
                "VMCluster is not intended to be used directly. See docstring for more info."
            )
        self._n_workers = n_workers

        if not security:
            self.security = None
        elif security is True:
            # True indicates self-signed temporary credentials should be used
            self.security = Security.temporary()
        elif not isinstance(security, Security):
            raise TypeError("security must be a Security object")
        else:
            self.security = security

        if protocol is None:
            if self.security and self.security.require_encryption:
                self.protocol = "tls"
            else:
                self.protocol = "tcp"
        else:
            self.protocol = protocol

        if self.security and self.security.require_encryption:
            dask.config.set({
                "distributed.comm.default-scheme": self.protocol,
                "distributed.comm.require-encryption": True,
                "distributed.comm.tls.ca-file": self.security.tls_ca_file,
                "distributed.comm.tls.scheduler.key": self.security.tls_scheduler_key,
                "distributed.comm.tls.scheduler.cert": self.security.tls_scheduler_cert,
                "distributed.comm.tls.worker.key": self.security.tls_worker_key,
                "distributed.comm.tls.worker.cert": self.security.tls_worker_cert,
                "distributed.comm.tls.client.key": self.security.tls_client_key,
                "distributed.comm.tls.client.cert": self.security.tls_client_cert,
            })

        image = self.scheduler_options.get("docker_image",
                                           False) or docker_image
        self.options["docker_image"] = image
        self.scheduler_options["docker_image"] = image
        self.scheduler_options["env_vars"] = env_vars
        self.scheduler_options["protocol"] = protocol
        self.scheduler_options["scheduler_options"] = scheduler_options
        self.worker_options["env_vars"] = env_vars
        self.options["docker_args"] = docker_args
        self.scheduler_options["docker_args"] = docker_args
        self.worker_options["docker_args"] = docker_args
        self.worker_options["docker_image"] = image
        self.worker_options["worker_class"] = worker_class
        self.worker_options["protocol"] = protocol
        self.worker_options["worker_options"] = worker_options
        self.uuid = str(uuid.uuid4())[:8]

        super().__init__(**kwargs, security=self.security)
Example #10
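# Security.temporary() generates self-signed, in-memory credentials, which repr()
# reports as "Temporary (In-memory)".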
def test_repr_temp_keys():
    sec = Security.temporary()
    representation = repr(sec)
    assert "Temporary (In-memory)" in representation
Example #11
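# What appears to be a site-specific deployment config module: the CA certificate/key
# pair is reused as the scheduler, worker, and client certificate when building the
# role-specific Security objects for a TLS cluster.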
urllib_parse = six.moves.urllib_parse
urlparse = urllib_parse.urlparse
logger = logging.getLogger('distributed.preloading')


from dask_signal import *
from dask_usage import USAGE_INFO


log_dir = os.path.expanduser('~/dask/logs/')  # ATTENTION: this is the remote location!
# Notes for TLS mode: 1. paths must be absolute ('~' and wildcards must be expanded first);  2. the Client addr must be given in TLS protocol form "tls://host:port"
TLS_CA_FILE=os.path.expanduser('~/.dask/ca.crt')
TLS_CA_CERT=os.path.expanduser('~/.dask/ca.crt')
TLS_CA_KEY=os.path.expanduser('~/.dask/ca.key')

SECURITY = SEC = Security(tls_ca_file=TLS_CA_FILE, tls_scheduler_cert=TLS_CA_CERT, tls_scheduler_key=TLS_CA_KEY)
SECURITY_WORKER = Security(tls_ca_file=TLS_CA_FILE, tls_worker_cert=TLS_CA_CERT, tls_worker_key=TLS_CA_KEY)
SECURITY_CLIENT = Security(tls_ca_file=TLS_CA_FILE, tls_client_cert=TLS_CA_CERT, tls_client_key=TLS_CA_KEY, require_encryption=True)

SCHEDULER_PORT = 8786
GLOBAL_CLUSTER = list(map(lambda m:'tls://%s:%s'%(m, SCHEDULER_PORT), MASTERS))[0]


SSH_USER = '******'
SSH_MASTER_ip = 'gpu01.ops.zzyc.abael.com'
SSH_PKEY = os.path.expanduser('~/.ssh/id_rsa')
SSH_PUB = os.path.expanduser('~/.ssh/id_rsa.pub')
SSH_WORKER_python = '/usr/bin/python3.6'
SSH_PORT = 22
SSH_MASTER_port = 11111
SSH_NANNY_port = 22222
Example #12
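# CLI-style worker entry point: a Security object is built from the TLS flags only when
# the CA file, certificate, and key are all provided (otherwise security is None), and
# is then passed on to CUDAWorker.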
def main(
    scheduler,
    host,
    nthreads,
    name,
    memory_limit,
    device_memory_limit,
    rmm_pool_size,
    rmm_managed_memory,
    pid_file,
    resources,
    dashboard,
    dashboard_address,
    local_directory,
    scheduler_file,
    interface,
    death_timeout,
    preload,
    dashboard_prefix,
    tls_ca_file,
    tls_cert,
    tls_key,
    enable_tcp_over_ucx,
    enable_infiniband,
    enable_nvlink,
    enable_rdmacm,
    net_devices,
    **kwargs,
):
    if tls_ca_file and tls_cert and tls_key:
        security = Security(
            tls_ca_file=tls_ca_file,
            tls_worker_cert=tls_cert,
            tls_worker_key=tls_key,
        )
    else:
        security = None

    worker = CUDAWorker(
        scheduler,
        host,
        nthreads,
        name,
        memory_limit,
        device_memory_limit,
        rmm_pool_size,
        rmm_managed_memory,
        pid_file,
        resources,
        dashboard,
        dashboard_address,
        local_directory,
        scheduler_file,
        interface,
        death_timeout,
        preload,
        dashboard_prefix,
        security,
        enable_tcp_over_ucx,
        enable_infiniband,
        enable_nvlink,
        enable_rdmacm,
        net_devices,
        **kwargs,
    )

    async def on_signal(signum):
        logger.info("Exiting on signal %d", signum)
        await worker.close()

    async def run():
        await worker
        await worker.finished()

    loop = IOLoop.current()

    install_signal_handlers(loop, cleanup=on_signal)

    try:
        loop.run_sync(run)
    except (KeyboardInterrupt, TimeoutError):
        pass
    finally:
        logger.info("End worker")
Example #13
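# Older scheduler CLI: the TLS flags map onto the scheduler role (tls_scheduler_cert /
# tls_scheduler_key), and providing any of them switches the default listen address to
# 'tls://'.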
def main(host, port, bokeh_port, show, _bokeh, bokeh_whitelist, bokeh_prefix,
         use_xheaders, pid_file, scheduler_file, interface, local_directory,
         preload, preload_argv, tls_ca_file, tls_cert, tls_key):

    enable_proctitle_on_current()
    enable_proctitle_on_children()

    sec = Security(
        tls_ca_file=tls_ca_file,
        tls_scheduler_cert=tls_cert,
        tls_scheduler_key=tls_key,
    )

    if not host and (tls_ca_file or tls_cert or tls_key):
        host = 'tls://'

    if pid_file:
        with open(pid_file, 'w') as f:
            f.write(str(os.getpid()))

        def del_pid_file():
            if os.path.exists(pid_file):
                os.remove(pid_file)

        atexit.register(del_pid_file)

    local_directory_created = False
    if local_directory:
        if not os.path.exists(local_directory):
            os.mkdir(local_directory)
            local_directory_created = True
    else:
        local_directory = tempfile.mkdtemp(prefix='scheduler-')
        local_directory_created = True
    if local_directory not in sys.path:
        sys.path.insert(0, local_directory)

    if sys.platform.startswith('linux'):
        import resource  # module fails importing on Windows
        soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
        limit = max(soft, hard // 2)
        resource.setrlimit(resource.RLIMIT_NOFILE, (limit, hard))

    if interface:
        if host:
            raise ValueError("Can not specify both interface and host")
        else:
            host = get_ip_interface(interface)

    addr = uri_from_host_port(host, port, 8786)

    loop = IOLoop.current()
    logger.info('-' * 47)

    services = {}
    if _bokeh:
        with ignoring(ImportError):
            from distributed.bokeh.scheduler import BokehScheduler
            services[('bokeh', bokeh_port)] = (BokehScheduler, {
                'prefix': bokeh_prefix
            })
    scheduler = Scheduler(loop=loop,
                          services=services,
                          scheduler_file=scheduler_file,
                          security=sec)
    scheduler.start(addr)
    if not preload:
        preload = dask.config.get('distributed.scheduler.preload')
    if not preload_argv:
        preload_argv = dask.config.get('distributed.scheduler.preload-argv')
    preload_modules(preload,
                    parameter=scheduler,
                    file_dir=local_directory,
                    argv=preload_argv)

    logger.info('Local Directory: %26s', local_directory)
    logger.info('-' * 47)

    install_signal_handlers(loop)

    try:
        loop.start()
        loop.close()
    finally:
        scheduler.stop()
        if local_directory_created:
            shutil.rmtree(local_directory)

        logger.info("End scheduler at %r", addr)
Example #14
def test_listen_args():
    def basic_checks(ctx):
        assert ctx.verify_mode == ssl.CERT_REQUIRED
        assert ctx.check_hostname is False

    def many_ciphers(ctx):
        if sys.version_info >= (3, 6):
            assert len(ctx.get_ciphers()) > 2  # Most likely

    c = {
        'tls': {
            'ca-file': ca_file,
            'scheduler': {
                'key': key1,
                'cert': cert1,
            },
            'worker': {
                'cert': keycert1,
            },
        },
    }
    with new_config(c):
        sec = Security()

    d = sec.get_listen_args('scheduler')
    assert not d['require_encryption']
    ctx = d['ssl_context']
    basic_checks(ctx)
    many_ciphers(ctx)

    d = sec.get_listen_args('worker')
    ctx = d['ssl_context']
    basic_checks(ctx)
    many_ciphers(ctx)

    # No cert defined => no TLS
    d = sec.get_listen_args('client')
    assert d.get('ssl_context') is None

    # With more settings
    c['tls']['ciphers'] = FORCED_CIPHER
    c['require-encryption'] = True

    with new_config(c):
        sec = Security()

    d = sec.get_listen_args('scheduler')
    assert d['require_encryption']
    ctx = d['ssl_context']
    basic_checks(ctx)
    if sys.version_info >= (3, 6):
        supported_ciphers = ctx.get_ciphers()
        tls_12_ciphers = [
            c for c in supported_ciphers if c['protocol'] == 'TLSv1.2'
        ]
        assert len(tls_12_ciphers) == 1
        tls_13_ciphers = [
            c for c in supported_ciphers if c['protocol'] == 'TLSv1.3'
        ]
        if len(tls_13_ciphers):
            assert len(tls_13_ciphers) == 3
Example #15
def main(host, port, http_port, bokeh_port, bokeh_internal_port, show, _bokeh,
         bokeh_whitelist, bokeh_prefix, use_xheaders, pid_file, scheduler_file,
         interface, local_directory, preload, prefix, tls_ca_file, tls_cert,
         tls_key):

    if bokeh_internal_port:
        print("The --bokeh-internal-port keyword has been removed.\n"
              "The internal bokeh server is now the default bokeh server.\n"
              "Use --bokeh-port %d instead" % bokeh_internal_port)
        sys.exit(1)

    if prefix:
        print("The --prefix keyword has moved to --bokeh-prefix")
        sys.exit(1)

    sec = Security(
        tls_ca_file=tls_ca_file,
        tls_scheduler_cert=tls_cert,
        tls_scheduler_key=tls_key,
    )

    if pid_file:
        with open(pid_file, 'w') as f:
            f.write(str(os.getpid()))

        def del_pid_file():
            if os.path.exists(pid_file):
                os.remove(pid_file)

        atexit.register(del_pid_file)

    local_directory_created = False
    if local_directory:
        if not os.path.exists(local_directory):
            os.mkdir(local_directory)
            local_directory_created = True
    else:
        local_directory = tempfile.mkdtemp(prefix='scheduler-')
        local_directory_created = True
    if local_directory not in sys.path:
        sys.path.insert(0, local_directory)

    if sys.platform.startswith('linux'):
        import resource  # module fails importing on Windows
        soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
        limit = max(soft, hard // 2)
        resource.setrlimit(resource.RLIMIT_NOFILE, (limit, hard))

    if interface:
        if host:
            raise ValueError("Can not specify both interface and host")
        else:
            host = get_ip_interface(interface)

    addr = uri_from_host_port(host, port, 8786)

    loop = IOLoop.current()
    logger.info('-' * 47)

    services = {('http', http_port): HTTPScheduler}
    if _bokeh:
        with ignoring(ImportError):
            from distributed.bokeh.scheduler import BokehScheduler
            services[('bokeh', bokeh_port)] = partial(BokehScheduler,
                                                      prefix=bokeh_prefix)
    scheduler = Scheduler(loop=loop,
                          services=services,
                          scheduler_file=scheduler_file,
                          security=sec)
    scheduler.start(addr)
    preload_modules(preload, parameter=scheduler, file_dir=local_directory)

    logger.info('Local Directory: %26s', local_directory)
    logger.info('-' * 47)
    try:
        loop.start()
        loop.close()
    finally:
        scheduler.stop()
        if local_directory_created:
            shutil.rmtree(local_directory)

        logger.info("End scheduler at %r", addr)
Example #16
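# Unknown keyword arguments to Security() raise a TypeError naming the offending keyword.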
def test_constructor_errors():
    with pytest.raises(TypeError) as exc:
        Security(unknown_keyword="bar")

    assert "unknown_keyword" in str(exc.value)
Example #17
def test_require_encryption():
    """
    Functional test for "require_encryption" setting.
    """
    @gen.coroutine
    def handle_comm(comm):
        comm.abort()

    c = {
        "tls": {
            "ca-file": ca_file,
            "scheduler": {
                "key": key1,
                "cert": cert1
            },
            "worker": {
                "cert": keycert1
            },
        }
    }
    with new_config(c):
        sec = Security()
    c["require-encryption"] = True
    with new_config(c):
        sec2 = Security()

    for listen_addr in ["inproc://", "tls://"]:
        with listen(
                listen_addr,
                handle_comm,
                connection_args=sec.get_listen_args("scheduler")) as listener:
            comm = yield connect(
                listener.contact_address,
                connection_args=sec2.get_connection_args("worker"),
            )
            comm.abort()

        with listen(
                listen_addr,
                handle_comm,
                connection_args=sec2.get_listen_args("scheduler")) as listener:
            comm = yield connect(
                listener.contact_address,
                connection_args=sec2.get_connection_args("worker"),
            )
            comm.abort()

    @contextmanager
    def check_encryption_error():
        with pytest.raises(RuntimeError) as excinfo:
            yield
        assert "encryption required" in str(excinfo.value)

    for listen_addr in ["tcp://"]:
        with listen(
                listen_addr,
                handle_comm,
                connection_args=sec.get_listen_args("scheduler")) as listener:
            comm = yield connect(
                listener.contact_address,
                connection_args=sec.get_connection_args("worker"),
            )
            comm.abort()

            with pytest.raises(RuntimeError):
                yield connect(
                    listener.contact_address,
                    connection_args=sec2.get_connection_args("worker"),
                )

        with pytest.raises(RuntimeError):
            listen(
                listen_addr,
                handle_comm,
                connection_args=sec2.get_listen_args("scheduler"),
            )
Example #18
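# Newer scheduler CLI: only TLS flags that were actually provided (i.e. not None) are
# forwarded to Security().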
def main(
    host,
    port,
    bokeh_port,
    show,
    dashboard,
    bokeh,
    dashboard_prefix,
    use_xheaders,
    pid_file,
    local_directory,
    tls_ca_file,
    tls_cert,
    tls_key,
    dashboard_address,
    **kwargs
):
    g0, g1, g2 = gc.get_threshold()  # https://github.com/dask/distributed/issues/1653
    gc.set_threshold(g0 * 3, g1 * 3, g2 * 3)

    enable_proctitle_on_current()
    enable_proctitle_on_children()

    if bokeh_port is not None:
        warnings.warn(
            "The --bokeh-port flag has been renamed to --dashboard-address. "
            "Consider adding ``--dashboard-address :%d`` " % bokeh_port
        )
        dashboard_address = bokeh_port
    if bokeh is not None:
        warnings.warn(
            "The --bokeh/--no-bokeh flag has been renamed to --dashboard/--no-dashboard. "
        )
        dashboard = bokeh

    if port is None and (not host or not re.search(r":\d", host)):
        port = 8786

    sec = Security(
        **{
            k: v
            for k, v in [
                ("tls_ca_file", tls_ca_file),
                ("tls_scheduler_cert", tls_cert),
                ("tls_scheduler_key", tls_key),
            ]
            if v is not None
        }
    )

    if not host and (tls_ca_file or tls_cert or tls_key):
        host = "tls://"

    if pid_file:
        with open(pid_file, "w") as f:
            f.write(str(os.getpid()))

        def del_pid_file():
            if os.path.exists(pid_file):
                os.remove(pid_file)

        atexit.register(del_pid_file)

    local_directory_created = False
    if local_directory:
        if not os.path.exists(local_directory):
            os.mkdir(local_directory)
            local_directory_created = True
    else:
        local_directory = tempfile.mkdtemp(prefix="scheduler-")
        local_directory_created = True
    if local_directory not in sys.path:
        sys.path.insert(0, local_directory)

    if sys.platform.startswith("linux"):
        import resource  # module fails importing on Windows

        soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
        limit = max(soft, hard // 2)
        resource.setrlimit(resource.RLIMIT_NOFILE, (limit, hard))

    loop = IOLoop.current()
    logger.info("-" * 47)

    scheduler = Scheduler(
        loop=loop,
        security=sec,
        host=host,
        port=port,
        dashboard_address=dashboard_address if dashboard else None,
        service_kwargs={"dashboard": {"prefix": dashboard_prefix}},
        **kwargs,
    )
    logger.info("Local Directory: %26s", local_directory)
    logger.info("-" * 47)

    install_signal_handlers(loop)

    async def run():
        await scheduler
        await scheduler.finished()

    try:
        loop.run_sync(run)
    finally:
        scheduler.stop()
        if local_directory_created:
            shutil.rmtree(local_directory)

        logger.info("End scheduler at %r", scheduler.address)
Example #19
def test_listen_args():
    def basic_checks(ctx):
        assert ctx.verify_mode == ssl.CERT_REQUIRED
        assert ctx.check_hostname is False

    def many_ciphers(ctx):
        if sys.version_info >= (3, 6):
            assert len(ctx.get_ciphers()) > 2  # Most likely

    c = {
        "tls": {
            "ca-file": ca_file,
            "scheduler": {
                "key": key1,
                "cert": cert1
            },
            "worker": {
                "cert": keycert1
            },
        }
    }
    with new_config(c):
        sec = Security()

    d = sec.get_listen_args("scheduler")
    assert not d["require_encryption"]
    ctx = d["ssl_context"]
    basic_checks(ctx)
    many_ciphers(ctx)

    d = sec.get_listen_args("worker")
    ctx = d["ssl_context"]
    basic_checks(ctx)
    many_ciphers(ctx)

    # No cert defined => no TLS
    d = sec.get_listen_args("client")
    assert d.get("ssl_context") is None

    # With more settings
    c["tls"]["ciphers"] = FORCED_CIPHER
    c["require-encryption"] = True

    with new_config(c):
        sec = Security()

    d = sec.get_listen_args("scheduler")
    assert d["require_encryption"]
    ctx = d["ssl_context"]
    basic_checks(ctx)
    if sys.version_info >= (3, 6):
        supported_ciphers = ctx.get_ciphers()
        tls_12_ciphers = [
            c for c in supported_ciphers if c["protocol"] == "TLSv1.2"
        ]
        assert len(tls_12_ciphers) == 1
        tls_13_ciphers = [
            c for c in supported_ciphers if c["protocol"] == "TLSv1.3"
        ]
        if len(tls_13_ciphers):
            assert len(tls_13_ciphers) == 3
Example #20
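# With an empty config, repr() of a Security object lists exactly the fields that were
# explicitly set.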
def test_repr():
    with new_config({}):
        sec = Security(tls_ca_file="ca.pem", tls_scheduler_cert="scert.pem")
        assert (
            repr(sec) ==
            "Security(tls_ca_file='ca.pem', tls_scheduler_cert='scert.pem')")
Example #21
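# Worker CLI: the tls_ca_file / tls_cert / tls_key flags map onto the worker role
# (tls_worker_cert / tls_worker_key), and only non-None values reach Security().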
def main(scheduler, host, worker_port, listen_address, contact_address,
         nanny_port, nthreads, nprocs, nanny, name, pid_file, resources,
         dashboard, bokeh, bokeh_port, scheduler_file, dashboard_prefix,
         tls_ca_file, tls_cert, tls_key, dashboard_address, **kwargs):
    g0, g1, g2 = gc.get_threshold()  # https://github.com/dask/distributed/issues/1653
    gc.set_threshold(g0 * 3, g1 * 3, g2 * 3)

    enable_proctitle_on_current()
    enable_proctitle_on_children()

    if bokeh_port is not None:
        warnings.warn(
            "The --bokeh-port flag has been renamed to --dashboard-address. "
            "Consider adding ``--dashboard-address :%d`` " % bokeh_port)
        dashboard_address = bokeh_port
    if bokeh is not None:
        warnings.warn(
            "The --bokeh/--no-bokeh flag has been renamed to --dashboard/--no-dashboard. "
        )
        dashboard = bokeh

    sec = Security(
        **{
            k: v
            for k, v in [
                ("tls_ca_file", tls_ca_file),
                ("tls_worker_cert", tls_cert),
                ("tls_worker_key", tls_key),
            ] if v is not None
        })

    if nprocs > 1 and worker_port != 0:
        logger.error(
            "Failed to launch worker.  You cannot use the --port argument when nprocs > 1."
        )
        exit(1)

    if nprocs > 1 and not nanny:
        logger.error(
            "Failed to launch worker.  You cannot use the --no-nanny argument when nprocs > 1."
        )
        exit(1)

    if contact_address and not listen_address:
        logger.error(
            "Failed to launch worker. "
            "Must specify --listen-address when --contact-address is given")
        exit(1)

    if nprocs > 1 and listen_address:
        logger.error("Failed to launch worker. "
                     "You cannot specify --listen-address when nprocs > 1.")
        exit(1)

    if (worker_port or host) and listen_address:
        logger.error(
            "Failed to launch worker. "
            "You cannot specify --listen-address when --worker-port or --host is given."
        )
        exit(1)

    try:
        if listen_address:
            (host, worker_port) = get_address_host_port(listen_address,
                                                        strict=True)

        if contact_address:
            # we only need this to verify it is getting parsed
            (_, _) = get_address_host_port(contact_address, strict=True)
        else:
            # if contact address is not present we use the listen_address for contact
            contact_address = listen_address
    except ValueError as e:
        logger.error("Failed to launch worker. " + str(e))
        exit(1)

    if nanny:
        port = nanny_port
    else:
        port = worker_port

    if not nthreads:
        nthreads = CPU_COUNT // nprocs

    if pid_file:
        with open(pid_file, "w") as f:
            f.write(str(os.getpid()))

        def del_pid_file():
            if os.path.exists(pid_file):
                os.remove(pid_file)

        atexit.register(del_pid_file)

    if resources:
        resources = resources.replace(",", " ").split()
        resources = dict(pair.split("=") for pair in resources)
        resources = valmap(float, resources)
    else:
        resources = None

    loop = IOLoop.current()

    if nanny:
        kwargs.update({
            "worker_port": worker_port,
            "listen_address": listen_address
        })
        t = Nanny
    else:
        if nanny_port:
            kwargs["service_ports"] = {"nanny": nanny_port}
        t = Worker

    if (not scheduler and not scheduler_file
            and dask.config.get("scheduler-address", None) is None):
        raise ValueError("Need to provide scheduler address like\n"
                         "dask-worker SCHEDULER_ADDRESS:8786")

    with ignoring(TypeError, ValueError):
        name = int(name)

    nannies = [
        t(scheduler,
          scheduler_file=scheduler_file,
          nthreads=nthreads,
          loop=loop,
          resources=resources,
          security=sec,
          contact_address=contact_address,
          host=host,
          port=port,
          dashboard_address=dashboard_address if dashboard else None,
          service_kwargs={"dashboard": {
              "prefix": dashboard_prefix
          }},
          name=name if nprocs == 1 or name is None or name == "" else
          str(name) + "-" + str(i),
          **kwargs) for i in range(nprocs)
    ]

    @gen.coroutine
    def close_all():
        # Unregister all workers from scheduler
        if nanny:
            yield [n.close(timeout=2) for n in nannies]

    def on_signal(signum):
        logger.info("Exiting on signal %d", signum)
        close_all()

    @gen.coroutine
    def run():
        yield nannies
        yield [n.finished() for n in nannies]

    install_signal_handlers(loop, cleanup=on_signal)

    try:
        loop.run_sync(run)
    except TimeoutError:
        # We already log the exception in nanny / worker. Don't do it again.
        raise TimeoutError("Timed out starting worker.") from None
    except KeyboardInterrupt:
        pass
    finally:
        logger.info("End worker")
Example #22
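# Certificates given as file paths show up by name in the Security repr().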
def test_repr_local_keys():
    sec = Security(tls_ca_file="ca.pem", tls_scheduler_cert="scert.pem")
    representation = repr(sec)
    assert "ca.pem" in representation
    assert "scert.pem" in representation
Example #23
        MET = df['MET_pt']
        # We can define a new key for cutflow (in this case 'all events'). Then we can
        # put values into it. We need += because it's per-chunk (demonstrated below).
        output['cutflow']['all events'] += MET.size
        output['cutflow']['number of chunks'] += 1
        # This fills our histogram once our data is collected. Always use .flatten() to
        # make sure the array is reduced. The output key will be as defined in __init__
        # for self._accumulator; the hist key ('MET=') will be defined in the bin.
        output['MET'].fill(dataset=dataset, MET=MET.flatten())
        return output

    def postprocess(self, accumulator):
        return accumulator


sec_dask = Security(tls_ca_file='/etc/cmsaf-secrets/ca.pem',
                    tls_worker_cert='/etc/cmsaf-secrets/usercert.pem',
                    tls_worker_key='/etc/cmsaf-secrets/usercert.pem',
                    tls_client_cert='/etc/cmsaf-secrets/usercert.pem',
                    tls_client_key='/etc/cmsaf-secrets/usercert.pem',
                    tls_scheduler_cert='/etc/cmsaf-secrets/hostcert.pem',
                    tls_scheduler_key='/etc/cmsaf-secrets/hostcert.pem',
                    require_encryption=True)

HTCondorJob.submit_command = "condor_submit -spool"

cluster = HTCondorCluster(
    cores=4,
    memory="2GB",
    disk="1GB",
    log_directory="logs",
    silence_logs="debug",
    scheduler_options={
        "dashboard_address": "8786",
        "port": 8787,
Example #24
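# Async variant of the require_encryption test: configuration is set through
# dask.config.set with fully-qualified keys, and get_listen_args()/get_connection_args()
# results are unpacked directly into listen()/connect() as keyword arguments.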
async def test_require_encryption():
    """
    Functional test for "require_encryption" setting.
    """
    async def handle_comm(comm):
        comm.abort()

    c = {
        "distributed.comm.tls.ca-file": ca_file,
        "distributed.comm.tls.scheduler.key": key1,
        "distributed.comm.tls.scheduler.cert": cert1,
        "distributed.comm.tls.worker.cert": keycert1,
    }
    with dask.config.set(c):
        sec = Security()

    c["distributed.comm.require-encryption"] = True
    with dask.config.set(c):
        sec2 = Security()

    for listen_addr in ["inproc://", "tls://"]:
        async with listen(listen_addr, handle_comm,
                          **sec.get_listen_args("scheduler")) as listener:
            comm = await connect(listener.contact_address,
                                 **sec2.get_connection_args("worker"))
            comm.abort()

        async with listen(listen_addr, handle_comm,
                          **sec2.get_listen_args("scheduler")) as listener:
            comm = await connect(listener.contact_address,
                                 **sec2.get_connection_args("worker"))
            comm.abort()

    @contextmanager
    def check_encryption_error():
        with pytest.raises(RuntimeError) as excinfo:
            yield
        assert "encryption required" in str(excinfo.value)

    for listen_addr in ["tcp://"]:
        async with listen(listen_addr, handle_comm,
                          **sec.get_listen_args("scheduler")) as listener:
            comm = await connect(listener.contact_address,
                                 **sec.get_connection_args("worker"))
            comm.abort()

            with pytest.raises(RuntimeError):
                await connect(listener.contact_address,
                              **sec2.get_connection_args("worker"))

        with pytest.raises(RuntimeError):
            listen(listen_addr, handle_comm,
                   **sec2.get_listen_args("scheduler"))
Example #25
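    # Custom cluster constructor (apparently an MPI/SLURM-flavoured setup): when no
    # security object is supplied, it falls back to a default Security() built from the
    # Dask configuration.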
    def __init__(
            self,
            n_workers=1,
            # Cluster keywords
            loop=None,
            security=None,
            silence_logs="error",
            name=None,
            # Scheduler-only keywords
            scheduler_options=None,
            # Options for both scheduler and workers
            interface=None,
            protocol="tcp://",
            # Job keywords
            config_name=None,
            # Some SLURM only keywords
            scheduler_file: str = None,
            local_scheduler: bool = False,
            timeout: int = 500,
            **kwargs):
        self.status = "created"
        self.name = name

        if scheduler_file is None:
            scheduler_file = "~/scheduler.json"

        scheduler_file = os.path.expanduser(scheduler_file)
        if os.path.exists(scheduler_file):
            raise ValueError("A scheduler file was already found at " +
                             scheduler_file)
        # if interface is None:
        #     interface = dask.config.get("jobqueue.%s.interface" % config_name)
        if scheduler_options is None:
            scheduler_options = {}

        default_scheduler_options = {
            "protocol": protocol,
            "dashboard_address": ":8787",
            "security": security,
        }

        if local_scheduler:
            scheduler_options = dict(default_scheduler_options,
                                     **scheduler_options)
            self.scheduler_spec = {
                "cls": Scheduler,  # Use local scheduler for now
                "options": scheduler_options,
            }
        self.scheduler_file = scheduler_file

        kwargs["config_name"] = config_name
        kwargs["interface"] = interface

        self.protocol = protocol
        self.security = security or Security()
        self._kwargs = kwargs
        self.worker = None
        self.local = local_scheduler

        default_worker_options = {
            "cores": 1,
            "memory": "1GB",
            "python": shutil.which("python"),
            "dask_path": shutil.which("dask-mpi"),
            "threads_per_cpu": 4,
            "time": None,
            "constraint": None,
            "qos": None,
            "memory_limit": "0.5"
        }

        self.n_workers = 0

        self.worker_fields = default_worker_options.keys()
        self.worker_options = default_worker_options
        if kwargs is not None:
            self.worker_options.update({
                key: value
                for key, value in kwargs.items() if value is not None
            })
        print(kwargs)
        self._loop_runner = LoopRunner(asynchronous=True)
        self._loop = self._loop_runner.loop

        self._lock = asyncio.Lock()
        self._loop_runner.start()
        self._supports_scaling = False

        self.worker_count = int(n_workers)
        self.nodes = int(self.worker_options["cores"])
        self.memory = self.worker_options["memory"]
        self.threads_per_cpu = int(self.worker_options["threads_per_cpu"])
        self.time = self.worker_options["time"]
        self.constraint = self.worker_options["constraint"]
        self.qos = self.worker_options["qos"]
        self.dask_path = self.worker_options["dask_path"]
        self.memory_limit = self.worker_options["memory_limit"]

        # A really dirty hack to ensure that mpirun is in the desired path
        self.mpirun_path = os.path.join(os.path.dirname(self.dask_path),
                                        "mpirun")

        self.timeout = int(timeout)
        print(self.dask_path)
        if self.worker_options["python"] is None:
            raise ValueError(
                "Python path was unspecified and was not found in the environment path"
            )
        if self.dask_path is None:
            raise ValueError(
                "Dask MPI path was unspecified and was not found in the environment path"
            )

        super().__init__(asynchronous=True)
Example #26
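    # JobQueueCluster-style constructor: security=True requests temporary in-memory TLS
    # credentials via Security.temporary() (which needs the `cryptography` package), and
    # supplying any security object implies the 'tls://' protocol by default.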
    def __init__(
            self,
            n_workers=0,
            job_cls: Job = None,
            # Cluster keywords
            loop=None,
            security=None,
            shared_temp_directory=None,
            silence_logs="error",
            name=None,
            asynchronous=False,
            # Scheduler-only keywords
            dashboard_address=None,
            host=None,
            scheduler_options=None,
            scheduler_cls=Scheduler,  # Use local scheduler for now
            # Options for both scheduler and workers
            interface=None,
            protocol=None,
            # Job keywords
            config_name=None,
            **job_kwargs):
        self.status = Status.created

        default_job_cls = getattr(type(self), "job_cls", None)
        self.job_cls = default_job_cls
        if job_cls is not None:
            self.job_cls = job_cls

        if self.job_cls is None:
            raise ValueError(
                "You need to specify a Job type. Two cases:\n"
                "- you are inheriting from JobQueueCluster (most likely): you need to add a 'job_cls' class variable "
                "in your JobQueueCluster-derived class {}\n"
                "- you are using JobQueueCluster directly (less likely, only useful for tests): "
                "please explicitly pass a Job type through the 'job_cls' parameter."
                .format(type(self)))

        if dashboard_address is not None:
            raise ValueError(
                "Please pass 'dashboard_address' through 'scheduler_options': use\n"
                'cluster = {0}(..., scheduler_options={{"dashboard_address": ":12345"}}) rather than\n'
                'cluster = {0}(..., dashboard_address="12435")'.format(
                    self.__class__.__name__))

        if host is not None:
            raise ValueError(
                "Please pass 'host' through 'scheduler_options': use\n"
                'cluster = {0}(..., scheduler_options={{"host": "your-host"}}) rather than\n'
                'cluster = {0}(..., host="your-host")'.format(
                    self.__class__.__name__))

        default_config_name = self.job_cls.default_config_name()
        if config_name is None:
            config_name = default_config_name

        if interface is None:
            interface = dask.config.get("jobqueue.%s.interface" % config_name)
        if scheduler_options is None:
            scheduler_options = dask.config.get(
                "jobqueue.%s.scheduler-options" % config_name, {})

        if protocol is None and security is not None:
            protocol = "tls://"

        if security is True:
            try:
                security = Security.temporary()
            except ImportError:
                raise ImportError(
                    "In order to use TLS without pregenerated certificates `cryptography` is required,"
                    "please install it using either pip or conda")

        default_scheduler_options = {
            "protocol": protocol,
            "dashboard_address": ":8787",
            "security": security,
        }
        # scheduler_options overrides parameters common to both workers and scheduler
        scheduler_options = dict(default_scheduler_options,
                                 **scheduler_options)

        # Use the same network interface as the workers if scheduler ip has not
        # been set through scheduler_options via 'host' or 'interface'
        if "host" not in scheduler_options and "interface" not in scheduler_options:
            scheduler_options["interface"] = interface

        scheduler = {
            "cls": scheduler_cls,
            "options": scheduler_options,
        }

        if shared_temp_directory is None:
            shared_temp_directory = dask.config.get(
                "jobqueue.%s.shared-temp-directory" % config_name)
        self.shared_temp_directory = shared_temp_directory

        job_kwargs["config_name"] = config_name
        job_kwargs["interface"] = interface
        job_kwargs["protocol"] = protocol
        job_kwargs["security"] = self._get_worker_security(security)

        self._job_kwargs = job_kwargs

        worker = {"cls": self.job_cls, "options": self._job_kwargs}
        if "processes" in self._job_kwargs and self._job_kwargs[
                "processes"] > 1:
            worker["group"] = [
                "-" + str(i) for i in range(self._job_kwargs["processes"])
            ]

        self._dummy_job  # trigger property to ensure that the job is valid

        super().__init__(
            scheduler=scheduler,
            worker=worker,
            loop=loop,
            security=security,
            silence_logs=silence_logs,
            asynchronous=asynchronous,
            name=name,
        )

        if n_workers:
            self.scale(n_workers)
Example #27
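    # Nanny constructor: `security` may be given as a plain dict and is coerced into a
    # Security object; connection arguments are then derived for the 'worker' role.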
    def __init__(
        self,
        scheduler_ip=None,
        scheduler_port=None,
        scheduler_file=None,
        worker_port=0,
        nthreads=None,
        loop=None,
        local_dir=None,
        local_directory=None,
        services=None,
        name=None,
        memory_limit="auto",
        reconnect=True,
        validate=False,
        quiet=False,
        resources=None,
        silence_logs=None,
        death_timeout=None,
        preload=None,
        preload_argv=None,
        preload_nanny=None,
        preload_nanny_argv=None,
        security=None,
        contact_address=None,
        listen_address=None,
        worker_class=None,
        env=None,
        interface=None,
        host=None,
        port=None,
        protocol=None,
        config=None,
        **worker_kwargs,
    ):
        self._setup_logging(logger)
        self.loop = loop or IOLoop.current()

        if isinstance(security, dict):
            security = Security(**security)
        self.security = security or Security()
        assert isinstance(self.security, Security)
        self.connection_args = self.security.get_connection_args("worker")

        if local_dir is not None:
            warnings.warn("The local_dir keyword has moved to local_directory")
            local_directory = local_dir

        if local_directory is None:
            local_directory = dask.config.get(
                "temporary-directory") or os.getcwd()
            self._original_local_dir = local_directory
            local_directory = os.path.join(local_directory,
                                           "dask-worker-space")
        else:
            self._original_local_dir = local_directory

        self.local_directory = local_directory
        if not os.path.exists(self.local_directory):
            os.makedirs(self.local_directory, exist_ok=True)

        self.preload = preload
        if self.preload is None:
            self.preload = dask.config.get("distributed.worker.preload")
        self.preload_argv = preload_argv
        if self.preload_argv is None:
            self.preload_argv = dask.config.get(
                "distributed.worker.preload-argv")

        if preload_nanny is None:
            preload_nanny = dask.config.get("distributed.nanny.preload")
        if preload_nanny_argv is None:
            preload_nanny_argv = dask.config.get(
                "distributed.nanny.preload-argv")

        self.preloads = preloading.process_preloads(
            self,
            preload_nanny,
            preload_nanny_argv,
            file_dir=self.local_directory)

        if scheduler_file:
            cfg = json_load_robust(scheduler_file)
            self.scheduler_addr = cfg["address"]
        elif scheduler_ip is None and dask.config.get("scheduler-address"):
            self.scheduler_addr = dask.config.get("scheduler-address")
        elif scheduler_port is None:
            self.scheduler_addr = coerce_to_address(scheduler_ip)
        else:
            self.scheduler_addr = coerce_to_address(
                (scheduler_ip, scheduler_port))

        if protocol is None:
            protocol_address = self.scheduler_addr.split("://")
            if len(protocol_address) == 2:
                protocol = protocol_address[0]

        self._given_worker_port = worker_port
        self.nthreads = nthreads or CPU_COUNT
        self.reconnect = reconnect
        self.validate = validate
        self.resources = resources
        self.death_timeout = parse_timedelta(death_timeout)

        self.Worker = Worker if worker_class is None else worker_class
        config_environ = dask.config.get("distributed.nanny.environ", {})
        if not isinstance(config_environ, dict):
            raise TypeError(
                "distributed.nanny.environ configuration must be of type dict. "
                f"Instead got {type(config_environ)}")
        self.env = config_environ.copy()
        for k in self.env:
            if k in os.environ:
                self.env[k] = os.environ[k]
        if env:
            self.env.update(env)
        self.env = {k: str(v) for k, v in self.env.items()}
        self.config = config or dask.config.config
        worker_kwargs.update({
            "port": worker_port,
            "interface": interface,
            "protocol": protocol,
            "host": host,
        })
        self.worker_kwargs = worker_kwargs

        self.contact_address = contact_address

        self.services = services
        self.name = name
        self.quiet = quiet
        self.auto_restart = True

        if silence_logs:
            silence_logging(level=silence_logs)
        self.silence_logs = silence_logs

        handlers = {
            "instantiate": self.instantiate,
            "kill": self.kill,
            "restart": self.restart,
            # cannot call it 'close' on the rpc side for naming conflict
            "get_logs": self.get_logs,
            "terminate": self.close,
            "close_gracefully": self.close_gracefully,
            "run": self.run,
            "plugin_add": self.plugin_add,
            "plugin_remove": self.plugin_remove,
        }

        self.plugins: dict[str, NannyPlugin] = {}

        super().__init__(handlers=handlers,
                         io_loop=self.loop,
                         connection_args=self.connection_args)

        self.scheduler = self.rpc(self.scheduler_addr)
        self.memory_manager = NannyMemoryManager(self,
                                                 memory_limit=memory_limit)

        if (not host and not interface
                and not self.scheduler_addr.startswith("inproc://")):
            host = get_ip(get_address_host(self.scheduler.address))

        self._start_port = port
        self._start_host = host
        self._interface = interface
        self._protocol = protocol

        self._listen_address = listen_address
        Nanny._instances.add(self)
        self.status = Status.init
Example #28
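# Older dask_cuda CLI variant: one Nanny is started per visible GPU, and all of them
# share the same worker-role Security object built from the TLS flags.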
def main(
    scheduler,
    host,
    nthreads,
    name,
    memory_limit,
    device_memory_limit,
    rmm_pool_size,
    pid_file,
    resources,
    dashboard,
    dashboard_address,
    local_directory,
    scheduler_file,
    interface,
    death_timeout,
    preload,
    dashboard_prefix,
    tls_ca_file,
    tls_cert,
    tls_key,
    enable_tcp_over_ucx,
    enable_infiniband,
    enable_nvlink,
    enable_rdmacm,
    net_devices,
    **kwargs,
):
    enable_proctitle_on_current()
    enable_proctitle_on_children()

    if tls_ca_file and tls_cert and tls_key:
        sec = Security(
            tls_ca_file=tls_ca_file, tls_worker_cert=tls_cert, tls_worker_key=tls_key
        )
    else:
        sec = None

    try:
        nprocs = len(os.environ["CUDA_VISIBLE_DEVICES"].split(","))
    except KeyError:
        nprocs = get_n_gpus()

    if not nthreads:
        nthreads = min(1, multiprocessing.cpu_count() // nprocs)

    memory_limit = parse_memory_limit(memory_limit, nthreads, total_cores=nprocs)

    if pid_file:
        with open(pid_file, "w") as f:
            f.write(str(os.getpid()))

        def del_pid_file():
            if os.path.exists(pid_file):
                os.remove(pid_file)

        atexit.register(del_pid_file)

    services = {}

    if dashboard:
        try:
            from distributed.dashboard import BokehWorker
        except ImportError:
            pass
        else:
            if dashboard_prefix:
                result = (BokehWorker, {"prefix": dashboard_prefix})
            else:
                result = BokehWorker
            services[("dashboard", dashboard_address)] = result

    if resources:
        resources = resources.replace(",", " ").split()
        resources = dict(pair.split("=") for pair in resources)
        resources = valmap(float, resources)
    else:
        resources = None

    loop = IOLoop.current()

    preload_argv = kwargs.get("preload_argv", [])
    kwargs = {"worker_port": None, "listen_address": None}
    t = Nanny

    if not scheduler and not scheduler_file and "scheduler-address" not in config:
        raise ValueError(
            "Need to provide scheduler address like\n"
            "dask-worker SCHEDULER_ADDRESS:8786"
        )

    if interface:
        if host:
            raise ValueError("Can not specify both interface and host")
        else:
            host = get_ip_interface(interface)

    if rmm_pool_size is not None:
        try:
            import rmm  # noqa F401
        except ImportError:
            raise ValueError(
                "RMM pool requested but module 'rmm' is not available. "
                "For installation instructions, please see "
                "https://github.com/rapidsai/rmm"
            )  # pragma: no cover
        rmm_pool_size = parse_bytes(rmm_pool_size)

    nannies = [
        t(
            scheduler,
            scheduler_file=scheduler_file,
            nthreads=nthreads,
            services=services,
            loop=loop,
            resources=resources,
            memory_limit=memory_limit,
            interface=get_ucx_net_devices(
                cuda_device_index=i,
                ucx_net_devices=net_devices,
                get_openfabrics=False,
                get_network=True,
            ),
            preload=(list(preload) or []) + ["dask_cuda.initialize"],
            preload_argv=(list(preload_argv) or []) + ["--create-cuda-context"],
            security=sec,
            env={"CUDA_VISIBLE_DEVICES": cuda_visible_devices(i)},
            plugins={CPUAffinity(get_cpu_affinity(i)), RMMPool(rmm_pool_size)},
            name=name if nprocs == 1 or not name else name + "-" + str(i),
            local_directory=local_directory,
            config={
                "ucx": get_ucx_config(
                    enable_tcp_over_ucx=enable_tcp_over_ucx,
                    enable_infiniband=enable_infiniband,
                    enable_nvlink=enable_nvlink,
                    enable_rdmacm=enable_rdmacm,
                    net_devices=net_devices,
                    cuda_device_index=i,
                )
            },
            data=(
                DeviceHostFile,
                {
                    "device_memory_limit": get_device_total_memory(index=i)
                    if (device_memory_limit == "auto" or device_memory_limit == int(0))
                    else parse_bytes(device_memory_limit),
                    "memory_limit": memory_limit,
                    "local_directory": local_directory,
                },
            ),
            **kwargs,
        )
        for i in range(nprocs)
    ]

    @gen.coroutine
    def close_all():
        # Unregister all workers from scheduler
        yield [n._close(timeout=2) for n in nannies]

    def on_signal(signum):
        logger.info("Exiting on signal %d", signum)
        close_all()

    @gen.coroutine
    def run():
        yield nannies
        yield [n.finished() for n in nannies]

    install_signal_handlers(loop, cleanup=on_signal)

    try:
        loop.run_sync(run)
    except (KeyboardInterrupt, TimeoutError):
        pass
    finally:
        logger.info("End worker")
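
Example #28 above starts one Nanny per visible GPU and pins each worker process to a single device through its CUDA_VISIBLE_DEVICES environment variable. The sketch below is a hypothetical stand-in for the cuda_visible_devices(i) helper used there, meant only to illustrate the per-worker rotation; the name rotate_cuda_visible_devices and the exact ordering are assumptions, not taken from dask_cuda.

def rotate_cuda_visible_devices(index, n_gpus):
    # Hypothetical helper: put GPU `index` first so the worker treats it as
    # its primary device, while keeping the other devices visible.
    order = [(index + j) % n_gpus for j in range(n_gpus)]
    return ",".join(str(device) for device in order)


for i in range(4):
    print(i, rotate_cuda_visible_devices(i, 4))
# 0 0,1,2,3
# 1 1,2,3,0
# 2 2,3,0,1
# 3 3,0,1,2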
Example #29
def test_require_encryption():
    """
    Functional test for "require_encryption" setting.
    """
    @gen.coroutine
    def handle_comm(comm):
        comm.abort()

    c = {
        'tls': {
            'ca-file': ca_file,
            'scheduler': {
                'key': key1,
                'cert': cert1,
            },
            'worker': {
                'cert': keycert1,
            },
        },
    }
    with new_config(c):
        sec = Security()
    c['require-encryption'] = True
    with new_config(c):
        sec2 = Security()

    for listen_addr in ['inproc://', 'tls://']:
        with listen(
                listen_addr,
                handle_comm,
                connection_args=sec.get_listen_args('scheduler')) as listener:
            comm = yield connect(
                listener.contact_address,
                connection_args=sec2.get_connection_args('worker'))
            comm.abort()

        with listen(
                listen_addr,
                handle_comm,
                connection_args=sec2.get_listen_args('scheduler')) as listener:
            comm = yield connect(
                listener.contact_address,
                connection_args=sec2.get_connection_args('worker'))
            comm.abort()

    @contextmanager
    def check_encryption_error():
        with pytest.raises(RuntimeError) as excinfo:
            yield
        assert "encryption required" in str(excinfo.value)

    for listen_addr in ['tcp://']:
        with listen(
                listen_addr,
                handle_comm,
                connection_args=sec.get_listen_args('scheduler')) as listener:
            comm = yield connect(
                listener.contact_address,
                connection_args=sec.get_connection_args('worker'))
            comm.abort()

            with pytest.raises(RuntimeError):
                yield connect(
                    listener.contact_address,
                    connection_args=sec2.get_connection_args('worker'))

        with pytest.raises(RuntimeError):
            listen(listen_addr,
                   handle_comm,
                   connection_args=sec2.get_listen_args('scheduler'))
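
The test above exercises the require-encryption setting end to end over real comms. As a quicker illustration, and assuming Security accepts the keyword arguments used elsewhere in these examples, the flag also shows up directly in the dictionaries returned by get_connection_args / get_listen_args; this is a minimal sketch, not part of the test suite.

from distributed.security import Security

sec_plain = Security()                          # encryption allowed but optional
sec_strict = Security(require_encryption=True)  # unencrypted tcp:// comms are refused

# Both argument dicts carry the flag that the comm layer checks.
print(sec_plain.get_connection_args("worker")["require_encryption"])   # False
print(sec_strict.get_connection_args("worker")["require_encryption"])  # True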
Example #30
def main(scheduler, host, worker_port, http_port, nanny_port, nthreads, nprocs,
         nanny, name, memory_limit, pid_file, reconnect, resources, bokeh,
         bokeh_port, local_directory, scheduler_file, interface, death_timeout,
         preload, bokeh_prefix, tls_ca_file, tls_cert, tls_key):
    sec = Security(
        tls_ca_file=tls_ca_file,
        tls_worker_cert=tls_cert,
        tls_worker_key=tls_key,
    )

    if nanny:
        port = nanny_port
    else:
        port = worker_port

    if nprocs > 1 and worker_port != 0:
        logger.error(
            "Failed to launch worker.  You cannot use the --port argument when nprocs > 1."
        )
        exit(1)

    if nprocs > 1 and name:
        logger.error(
            "Failed to launch worker.  You cannot use the --name argument when nprocs > 1."
        )
        exit(1)

    if nprocs > 1 and not nanny:
        logger.error(
            "Failed to launch worker.  You cannot use the --no-nanny argument when nprocs > 1."
        )
        exit(1)

    if not nthreads:
        nthreads = _ncores // nprocs

    if pid_file:
        with open(pid_file, 'w') as f:
            f.write(str(os.getpid()))

        def del_pid_file():
            if os.path.exists(pid_file):
                os.remove(pid_file)

        atexit.register(del_pid_file)

    services = {('http', http_port): HTTPWorker}

    if bokeh:
        try:
            from distributed.bokeh.worker import BokehWorker
        except ImportError:
            pass
        else:
            if bokeh_prefix:
                result = (BokehWorker, {'prefix': bokeh_prefix})
            else:
                result = BokehWorker
            services[('bokeh', bokeh_port)] = result

    if resources:
        resources = resources.replace(',', ' ').split()
        resources = dict(pair.split('=') for pair in resources)
        resources = valmap(float, resources)
    else:
        resources = None

    loop = IOLoop.current()

    if nanny:
        kwargs = {'worker_port': worker_port}
        t = Nanny
    else:
        kwargs = {}
        if nanny_port:
            kwargs['service_ports'] = {'nanny': nanny_port}
        t = Worker

    if scheduler_file:
        while not os.path.exists(scheduler_file):
            sleep(0.01)
        for i in range(10):
            try:
                with open(scheduler_file) as f:
                    cfg = json.load(f)
                scheduler = cfg['address']
                break
            except (ValueError, KeyError):  # race with scheduler on file
                sleep(0.01)

    if not scheduler:
        raise ValueError("Need to provide scheduler address like\n"
                         "dask-worker SCHEDULER_ADDRESS:8786")

    if interface:
        if host:
            raise ValueError("Can not specify both interface and host")
        else:
            host = get_ip_interface(interface)

    if host or port:
        addr = uri_from_host_port(host, port, 0)
    else:
        # Choose appropriate address for scheduler
        addr = None

    nannies = [
        t(scheduler,
          ncores=nthreads,
          services=services,
          name=name,
          loop=loop,
          resources=resources,
          memory_limit=memory_limit,
          reconnect=reconnect,
          local_dir=local_directory,
          death_timeout=death_timeout,
          preload=preload,
          security=sec,
          **kwargs) for i in range(nprocs)
    ]

    @gen.coroutine
    def close_all():
        try:
            if nanny:
                yield [n._close(timeout=2) for n in nannies]
        finally:
            loop.stop()

    def handle_signal(signum, frame):
        logger.info("Exiting on signal %d", signum)
        if loop._running:
            loop.add_callback_from_signal(loop.stop)
        else:
            exit(0)

    # NOTE: We can't use the generic install_signal_handlers() function from
    # distributed.cli.utils because we're handling the signal differently.
    signal.signal(signal.SIGINT, handle_signal)
    signal.signal(signal.SIGTERM, handle_signal)

    for n in nannies:
        n.start(addr)

    @gen.coroutine
    def run():
        while all(n.status != 'closed' for n in nannies):
            yield gen.sleep(0.2)

    try:
        loop.run_sync(run)
    except (KeyboardInterrupt, TimeoutError):
        pass
    finally:
        logger.info("End worker")

    # Clean exit: unregister all workers from scheduler
    loop.run_sync(close_all)
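
Several of these CLIs parse the --resources flag the same way: split a string such as "GPU=2,MEM=10e9" into name/value pairs and coerce the values to float. A standalone sketch of that parsing, using a dict comprehension instead of toolz.valmap so it has no extra dependencies:

def parse_resources(spec):
    # "GPU=2,MEM=10e9" -> {"GPU": 2.0, "MEM": 10000000000.0}
    pairs = spec.replace(",", " ").split()
    return {name: float(value) for name, value in (pair.split("=") for pair in pairs)}


print(parse_resources("GPU=2,MEM=10e9"))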
Example #31
import pytest

import dask

from distributed import Client, Scheduler, Worker
from distributed.comm import connect, listen, ws
from distributed.comm.core import FatalCommClosedError
from distributed.comm.registry import backends, get_backend
from distributed.security import Security
from distributed.utils_test import (  # noqa: F401
    cleanup, gen_cluster, get_client_ssl_context, get_server_ssl_context, inc,
)

from .test_comms import check_tls_extra

security = Security.temporary()


def test_registered():
    assert "ws" in backends
    backend = get_backend("ws")
    assert isinstance(backend, ws.WSBackend)


@pytest.mark.asyncio
async def test_listen_connect(cleanup):
    async def handle_comm(comm):
        while True:
            msg = await comm.read()
            await comm.write(msg)
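
Security.temporary(), used at module level above, generates in-memory self-signed TLS credentials so tests need no certificate files on disk. The following is a minimal sketch of how such an object is typically threaded through a small cluster, assuming the security=, protocol= and asynchronous= keywords behave as in the other examples; it is an illustration rather than a definitive recipe.

import asyncio

from distributed import Client, Scheduler, Worker
from distributed.security import Security


async def run_tls_cluster():
    sec = Security.temporary()
    # Every component gets the same Security object so the certificates match.
    async with Scheduler(security=sec, protocol="tls", dashboard_address=":0") as s:
        async with Worker(s.address, security=sec):
            async with Client(s.address, security=sec, asynchronous=True) as c:
                print(await c.submit(lambda x: x + 1, 10))  # 11


if __name__ == "__main__":
    asyncio.run(run_tls_cluster())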
Example #32
def main(
    scheduler,
    host,
    worker_port,
    listen_address,
    contact_address,
    nanny_port,
    nthreads,
    nprocs,
    nanny,
    name,
    memory_limit,
    pid_file,
    reconnect,
    resources,
    dashboard,
    bokeh_port,
    local_directory,
    scheduler_file,
    interface,
    protocol,
    death_timeout,
    preload,
    preload_argv,
    dashboard_prefix,
    tls_ca_file,
    tls_cert,
    tls_key,
    dashboard_address,
):
    # https://github.com/dask/distributed/issues/1653
    g0, g1, g2 = gc.get_threshold()
    gc.set_threshold(g0 * 3, g1 * 3, g2 * 3)

    enable_proctitle_on_current()
    enable_proctitle_on_children()

    if bokeh_port is not None:
        warnings.warn(
            "The --bokeh-port flag has been renamed to --dashboard-address. "
            "Consider adding ``--dashboard-address :%d`` " % bokeh_port)
        dashboard_address = bokeh_port

    sec = Security(tls_ca_file=tls_ca_file,
                   tls_worker_cert=tls_cert,
                   tls_worker_key=tls_key)

    if nprocs > 1 and worker_port != 0:
        logger.error(
            "Failed to launch worker.  You cannot use the --port argument when nprocs > 1."
        )
        exit(1)

    if nprocs > 1 and not nanny:
        logger.error(
            "Failed to launch worker.  You cannot use the --no-nanny argument when nprocs > 1."
        )
        exit(1)

    if contact_address and not listen_address:
        logger.error(
            "Failed to launch worker. "
            "Must specify --listen-address when --contact-address is given")
        exit(1)

    if nprocs > 1 and listen_address:
        logger.error("Failed to launch worker. "
                     "You cannot specify --listen-address when nprocs > 1.")
        exit(1)

    if (worker_port or host) and listen_address:
        logger.error(
            "Failed to launch worker. "
            "You cannot specify --listen-address when --worker-port or --host is given."
        )
        exit(1)

    try:
        if listen_address:
            (host, worker_port) = get_address_host_port(listen_address,
                                                        strict=True)

        if contact_address:
            # we only need this to verify it is getting parsed
            (_, _) = get_address_host_port(contact_address, strict=True)
        else:
            # if contact address is not present we use the listen_address for contact
            contact_address = listen_address
    except ValueError as e:
        logger.error("Failed to launch worker. " + str(e))
        exit(1)

    if nanny:
        port = nanny_port
    else:
        port = worker_port

    if not nthreads:
        nthreads = _ncores // nprocs

    if pid_file:
        with open(pid_file, "w") as f:
            f.write(str(os.getpid()))

        def del_pid_file():
            if os.path.exists(pid_file):
                os.remove(pid_file)

        atexit.register(del_pid_file)

    services = {}

    if resources:
        resources = resources.replace(",", " ").split()
        resources = dict(pair.split("=") for pair in resources)
        resources = valmap(float, resources)
    else:
        resources = None

    loop = IOLoop.current()

    if nanny:
        kwargs = {"worker_port": worker_port, "listen_address": listen_address}
        t = Nanny
    else:
        kwargs = {}
        if nanny_port:
            kwargs["service_ports"] = {"nanny": nanny_port}
        t = Worker

    if (not scheduler and not scheduler_file
            and dask.config.get("scheduler-address", None) is None):
        raise ValueError("Need to provide scheduler address like\n"
                         "dask-worker SCHEDULER_ADDRESS:8786")

    if death_timeout is not None:
        death_timeout = parse_timedelta(death_timeout, "s")

    nannies = [
        t(scheduler,
          scheduler_file=scheduler_file,
          ncores=nthreads,
          services=services,
          loop=loop,
          resources=resources,
          memory_limit=memory_limit,
          reconnect=reconnect,
          local_dir=local_directory,
          death_timeout=death_timeout,
          preload=preload,
          preload_argv=preload_argv,
          security=sec,
          contact_address=contact_address,
          interface=interface,
          protocol=protocol,
          host=host,
          port=port,
          dashboard_address=dashboard_address if dashboard else None,
          service_kwargs={"bokhe": {
              "prefix": dashboard_prefix
          }},
          name=name if nprocs == 1 or not name else name + "-" + str(i),
          **kwargs) for i in range(nprocs)
    ]

    @gen.coroutine
    def close_all():
        # Unregister all workers from scheduler
        if nanny:
            yield [n.close(timeout=2) for n in nannies]

    def on_signal(signum):
        logger.info("Exiting on signal %d", signum)
        close_all()

    @gen.coroutine
    def run():
        yield nannies
        while all(n.status != "closed" for n in nannies):
            yield gen.sleep(0.2)

    install_signal_handlers(loop, cleanup=on_signal)

    try:
        loop.run_sync(run)
    except (KeyboardInterrupt, TimeoutError):
        pass
    finally:
        logger.info("End worker")
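
Example #32 validates --listen-address and --contact-address by splitting each into a (host, port) pair. A small standalone sketch of that step, assuming get_address_host_port is importable from distributed.comm.addressing as in recent releases:

from distributed.comm.addressing import get_address_host_port

# strict=True insists on a fully qualified address that includes the scheme.
host, port = get_address_host_port("tcp://0.0.0.0:8786", strict=True)
print(host, port)  # 0.0.0.0 8786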
Example #33
def test_repr():
    sec = Security(tls_ca_file="ca.pem", tls_scheduler_cert="scert.pem")
    assert (
        repr(sec)
        == "Security(require_encryption=False, tls_ca_file='ca.pem', tls_scheduler_cert='scert.pem')"
    )
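
The keyword names accepted by Security mirror the nested TLS configuration used in the other examples: tls_ca_file corresponds to the ca-file key and tls_scheduler_cert to the scheduler cert key. The sketch below pairs the two spellings; the dotted dask.config paths are those of newer distributed releases and are an assumption relative to the older new_config helper used above.

import dask
from distributed.security import Security

kw = Security(tls_ca_file="ca.pem", tls_scheduler_cert="scert.pem")

with dask.config.set({"distributed.comm.tls.ca-file": "ca.pem",
                      "distributed.comm.tls.scheduler.cert": "scert.pem"}):
    cfg = Security()

print(kw.tls_ca_file == cfg.tls_ca_file)                # True
print(kw.tls_scheduler_cert == cfg.tls_scheduler_cert)  # True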
Example #34
def main(scheduler, host, worker_port, listen_address, contact_address,
         nanny_port, nthreads, nprocs, nanny, name, memory_limit, pid_file,
         reconnect, resources, bokeh, bokeh_port, local_directory,
         scheduler_file, interface, death_timeout, preload, bokeh_prefix,
         tls_ca_file, tls_cert, tls_key):
    sec = Security(
        tls_ca_file=tls_ca_file,
        tls_worker_cert=tls_cert,
        tls_worker_key=tls_key,
    )

    if nprocs > 1 and worker_port != 0:
        logger.error(
            "Failed to launch worker.  You cannot use the --port argument when nprocs > 1."
        )
        exit(1)

    if nprocs > 1 and name:
        logger.error(
            "Failed to launch worker.  You cannot use the --name argument when nprocs > 1."
        )
        exit(1)

    if nprocs > 1 and not nanny:
        logger.error(
            "Failed to launch worker.  You cannot use the --no-nanny argument when nprocs > 1."
        )
        exit(1)

    if contact_address and not listen_address:
        logger.error(
            "Failed to launch worker. "
            "Must specify --listen-address when --contact-address is given")
        exit(1)

    if nprocs > 1 and listen_address:
        logger.error("Failed to launch worker. "
                     "You cannot specify --listen-address when nprocs > 1.")
        exit(1)

    if (worker_port or host) and listen_address:
        logger.error(
            "Failed to launch worker. "
            "You cannot specify --listen-address when --worker-port or --host is given."
        )
        exit(1)

    try:
        if listen_address:
            (host, worker_port) = get_address_host_port(listen_address,
                                                        strict=True)

        if contact_address:
            # we only need this to verify it is getting parsed
            (_, _) = get_address_host_port(contact_address, strict=True)
        else:
            # if contact address is not present we use the listen_address for contact
            contact_address = listen_address
    except ValueError as e:
        logger.error("Failed to launch worker. " + str(e))
        exit(1)

    if nanny:
        port = nanny_port
    else:
        port = worker_port

    if not nthreads:
        nthreads = _ncores // nprocs

    if pid_file:
        with open(pid_file, 'w') as f:
            f.write(str(os.getpid()))

        def del_pid_file():
            if os.path.exists(pid_file):
                os.remove(pid_file)

        atexit.register(del_pid_file)

    services = {}

    if bokeh:
        try:
            from distributed.bokeh.worker import BokehWorker
        except ImportError:
            pass
        else:
            if bokeh_prefix:
                result = (BokehWorker, {'prefix': bokeh_prefix})
            else:
                result = BokehWorker
            services[('bokeh', bokeh_port)] = result

    if resources:
        resources = resources.replace(',', ' ').split()
        resources = dict(pair.split('=') for pair in resources)
        resources = valmap(float, resources)
    else:
        resources = None

    loop = IOLoop.current()

    if nanny:
        kwargs = {'worker_port': worker_port, 'listen_address': listen_address}
        t = Nanny
    else:
        kwargs = {}
        if nanny_port:
            kwargs['service_ports'] = {'nanny': nanny_port}
        t = Worker

    if not scheduler and not scheduler_file:
        raise ValueError("Need to provide scheduler address like\n"
                         "dask-worker SCHEDULER_ADDRESS:8786")

    if interface:
        if host:
            raise ValueError("Can not specify both interface and host")
        else:
            host = get_ip_interface(interface)

    if host or port:
        addr = uri_from_host_port(host, port, 0)
    else:
        # Choose appropriate address for scheduler
        addr = None

    nannies = [
        t(scheduler,
          scheduler_file=scheduler_file,
          ncores=nthreads,
          services=services,
          name=name,
          loop=loop,
          resources=resources,
          memory_limit=memory_limit,
          reconnect=reconnect,
          local_dir=local_directory,
          death_timeout=death_timeout,
          preload=preload,
          security=sec,
          contact_address=contact_address,
          **kwargs) for i in range(nprocs)
    ]

    @gen.coroutine
    def close_all():
        # Unregister all workers from scheduler
        if nanny:
            yield [n._close(timeout=2) for n in nannies]

    def on_signal(signum):
        logger.info("Exiting on signal %d", signum)
        close_all()

    @gen.coroutine
    def run():
        yield [n.start(addr) for n in nannies]
        while all(n.status != 'closed' for n in nannies):
            yield gen.sleep(0.2)

    install_signal_handlers(loop, cleanup=on_signal)

    try:
        loop.run_sync(run)
    except (KeyboardInterrupt, TimeoutError):
        pass
    finally:
        logger.info("End worker")
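
A pattern repeated in each of these worker CLIs is writing a pid file at startup and removing it again at interpreter exit via atexit. The same idea in isolation, with a temporary path standing in for the --pid-file argument:

import atexit
import os
import tempfile

pid_file = os.path.join(tempfile.gettempdir(), "dask-worker-example.pid")

with open(pid_file, "w") as f:
    f.write(str(os.getpid()))


def del_pid_file():
    # Registered with atexit so the file disappears on a clean shutdown.
    if os.path.exists(pid_file):
        os.remove(pid_file)


atexit.register(del_pid_file)
print("wrote", pid_file)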
Example #35
def main(
    scheduler,
    host,
    nthreads,
    name,
    memory_limit,
    device_memory_limit,
    rmm_pool_size,
    rmm_managed_memory,
    pid_file,
    resources,
    dashboard,
    dashboard_address,
    local_directory,
    scheduler_file,
    interface,
    death_timeout,
    preload,
    dashboard_prefix,
    tls_ca_file,
    tls_cert,
    tls_key,
    enable_tcp_over_ucx,
    enable_infiniband,
    enable_nvlink,
    enable_rdmacm,
    net_devices,
    enable_jit_unspill,
    **kwargs,
):
    if tls_ca_file and tls_cert and tls_key:
        security = Security(
            tls_ca_file=tls_ca_file, tls_worker_cert=tls_cert, tls_worker_key=tls_key,
        )
    else:
        security = None

    if isinstance(scheduler, str) and scheduler.startswith("-"):
        raise ValueError(
            "The scheduler address can't start with '-'. Please check "
            "your command line arguments, you probably attempted to use "
            "unsupported one. Scheduler address: %s" % scheduler
        )

    worker = CUDAWorker(
        scheduler,
        host,
        nthreads,
        name,
        memory_limit,
        device_memory_limit,
        rmm_pool_size,
        rmm_managed_memory,
        pid_file,
        resources,
        dashboard,
        dashboard_address,
        local_directory,
        scheduler_file,
        interface,
        death_timeout,
        preload,
        dashboard_prefix,
        security,
        enable_tcp_over_ucx,
        enable_infiniband,
        enable_nvlink,
        enable_rdmacm,
        net_devices,
        enable_jit_unspill,
        **kwargs,
    )

    async def on_signal(signum):
        logger.info("Exiting on signal %d", signum)
        await worker.close()

    async def run():
        await worker
        await worker.finished()

    loop = IOLoop.current()

    install_signal_handlers(loop, cleanup=on_signal)

    try:
        loop.run_sync(run)
    except (KeyboardInterrupt, TimeoutError):
        pass
    finally:
        logger.info("End worker")
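
Example #35 ends with the control flow shared by these CLIs: await the worker's startup, then block on finished() until a signal handler closes it. Below is a stripped-down, purely hypothetical stand-in (DummyWorker is invented for illustration) that reproduces the same flow without starting a real cluster.

import asyncio


class DummyWorker:
    """Hypothetical stand-in for CUDAWorker, only to show the start/finish flow."""

    def __init__(self):
        self._done = asyncio.get_running_loop().create_future()

    def __await__(self):
        # "await worker" stands in for waiting on startup; nothing to do here.
        async def _start():
            return self
        return _start().__await__()

    async def finished(self):
        # Block until close() resolves the future, like Worker.finished().
        await self._done

    async def close(self):
        if not self._done.done():
            self._done.set_result(None)


async def main():
    worker = DummyWorker()
    await worker                      # startup

    async def close_soon():           # plays the role of the signal handler
        await asyncio.sleep(0.1)
        await worker.close()

    asyncio.create_task(close_soon())
    await worker.finished()           # run until closed
    print("End worker")


asyncio.run(main())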