コード例 #1
0
def test_security_temporary(EnvSpecificCluster, loop):
    """Check a cluster built with ``Security.temporary()`` end to end.

    The generated job script must pass ``--tls-key``, ``--tls-cert`` and
    ``--tls-ca-file`` paths that exist, live in ``shared_temp_directory`` and
    carry the ``.dask-jobqueue.worker.*`` name prefixes. A submitted task must
    then run successfully over TLS.
    """
    import pytest

    # Security.temporary() needs `cryptography`; skip instead of erroring
    # when it is missing (as the other temporary-credential tests do).
    pytest.importorskip("cryptography")

    dirname = os.path.dirname(__file__)

    def _check_credential_file(job_script, flag, prefix):
        # Find the first path passed to `flag` and check where it was written.
        match = re.search(re.escape(flag) + r" (\S+)", job_script)
        assert match is not None, f"{flag} not found in job script"
        path = match.group(1)
        assert os.path.exists(path), path
        assert os.path.basename(path).startswith(prefix), path
        assert os.path.dirname(path) == dirname, path

    with EnvSpecificCluster(
            cores=1,
            memory="100MB",
            security=Security.temporary(),
            shared_temp_directory=dirname,
            protocol="tls",
            loop=loop,
    ) as cluster:
        assert cluster.security
        assert cluster.scheduler_spec["options"][
            "security"] == cluster.security
        job_script = cluster.job_script()
        assert "tls://" in job_script
        _check_credential_file(
            job_script, "--tls-key", ".dask-jobqueue.worker.key")
        _check_credential_file(
            job_script, "--tls-cert", ".dask-jobqueue.worker.cert")
        _check_credential_file(
            job_script, "--tls-ca-file", ".dask-jobqueue.worker.ca_file")

        cluster.scale(jobs=1)
        with Client(cluster) as client:
            future = client.submit(lambda x: x + 1, 10)
            result = future.result(timeout=30)
            assert result == 11
コード例 #2
0
def test_extra_conn_args_in_temporary_credentials():
    """``Security.temporary`` keeps the ``extra_conn_args`` it is given."""
    pytest.importorskip("cryptography")

    headers = {"X-Request-ID": "abcd"}
    sec = Security.temporary(extra_conn_args={"headers": headers})
    expected = {"headers": {"X-Request-ID": "abcd"}}
    assert sec.extra_conn_args == expected
コード例 #3
0
def test_temporary_credentials():
    """Temporary credentials are inline strings that ``repr`` does not expose.

    Every TLS field (CA file plus key/cert for client, scheduler and worker)
    must contain a newline, i.e. hold inline content rather than a single-line
    path, and none of them may appear verbatim in ``repr(sec)``.
    """
    import pytest

    # Security.temporary() requires `cryptography`; skip rather than error.
    pytest.importorskip("cryptography")

    sec = Security.temporary()
    sec_repr = repr(sec)
    fields = ["tls_ca_file"]
    fields.extend(f"tls_{role}_{kind}"
                  for role in ["client", "scheduler", "worker"]
                  for kind in ["key", "cert"])
    for f in fields:
        val = getattr(sec, f)
        assert "\n" in val
        assert val not in sec_repr
コード例 #4
0
def test_temporary_credentials():
    """Every TLS field of a temporary ``Security`` is multi-line and kept out of its repr."""
    pytest.importorskip("cryptography")

    sec = Security.temporary()
    rendered = repr(sec)
    names = ["tls_ca_file"] + [
        f"tls_{role}_{kind}"
        for role in ("client", "scheduler", "worker")
        for kind in ("key", "cert")
    ]
    for name in names:
        value = getattr(sec, name)
        assert "\n" in value
        assert value not in rendered
コード例 #5
0
async def test_wss_roundtrip():
    """Scatter a NumPy array over ``wss://`` and read it back unchanged.

    Scheduler, worker and client all share one set of temporary credentials.
    """
    np = pytest.importorskip("numpy")
    xfail_ssl_issue5601()
    pytest.importorskip("cryptography")
    sec = Security.temporary()
    async with Scheduler(
        protocol="wss://", security=sec, dashboard_address=":0"
    ) as scheduler:
        # The worker only has to be running; it is not referenced directly.
        async with Worker(scheduler.address, security=sec):
            async with Client(
                scheduler.address, security=sec, asynchronous=True
            ) as client:
                original = np.arange(100)
                remote = await client.scatter(original)
                fetched = await remote
                assert (original == fetched).all()
コード例 #6
0
async def test_http_and_comm_server(cleanup, dashboard, protocol, security, port):
    """Check when the HTTP server and the comm listener are shared.

    They must be the same server object exactly when ``port == 8787``, and
    separate otherwise. A simple task must run in either case.
    """
    if security:
        xfail_ssl_issue5601()
        pytest.importorskip("cryptography")
        security = Security.temporary()
    async with Scheduler(
        protocol=protocol, dashboard=dashboard, port=port, security=security
    ) as s:
        shared = s.http_server is s.listener.server
        assert shared == (port == 8787)
        # The worker just needs to be up; it is not referenced directly.
        async with Worker(s.address, protocol=protocol, security=security):
            async with Client(s.address, asynchronous=True, security=security) as c:
                assert await c.submit(lambda x: x + 1, 10) == 11
コード例 #7
0
async def test_tls_temporary_credentials_functional():
    """Run a TLS handshake and message exchange with temporary credentials.

    A ``tls://`` listener configured with the scheduler's listen args accepts
    a connection made with the worker's connection args. The server-side
    handler checks the peer address scheme and sends ``"hello"``, which the
    client must receive.
    """
    pytest.importorskip("cryptography")

    async def handle_comm(comm):
        peer_addr = comm.peer_address
        assert peer_addr.startswith("tls://")
        await comm.write("hello")
        await comm.close()

    sec = Security.temporary()

    async with listen("tls://", handle_comm,
                      **sec.get_listen_args("scheduler")) as listener:
        comm = await connect(listener.contact_address,
                             **sec.get_connection_args("worker"))
        try:
            msg = await comm.read()
            assert msg == "hello"
        finally:
            # Abort even if the read or assertion fails, so the client comm
            # is never left open.
            comm.abort()
コード例 #8
0
ファイル: test_ws.py プロジェクト: haraldschilly/distributed
async def test_expect_scheduler_ssl_when_sharing_server(tmpdir):
    """A ``ws://`` scheduler on port 8787 must raise when the dashboard has TLS.

    Temporary scheduler key/cert contents are written to files and configured
    as the dashboard TLS key/cert. Starting a non-TLS scheduler on that port
    must then raise ``RuntimeError``.
    """
    xfail_ssl_issue5601()
    pytest.importorskip("cryptography")
    creds = Security.temporary()
    base = str(tmpdir)
    key_path = os.path.join(base, "dask.pem")
    cert_path = os.path.join(base, "dask.crt")
    for path, contents in (
        (key_path, creds.tls_scheduler_key),
        (cert_path, creds.tls_scheduler_cert),
    ):
        with open(path, "w") as fh:
            fh.write(contents)
    dashboard_tls = {
        "distributed.scheduler.dashboard.tls.key": key_path,
        "distributed.scheduler.dashboard.tls.cert": cert_path,
    }
    with dask.config.set(dashboard_tls), pytest.raises(RuntimeError):
        async with Scheduler(protocol="ws://", dashboard=True, port=8787):
            pass
コード例 #9
0
async def test_connection_made_with_extra_conn_args(cleanup, protocol):
    """Check that ``extra_conn_args`` headers reach the WebSocket request.

    For ``ws://`` a plain ``Security`` carries the headers. Otherwise,
    temporary TLS credentials are generated with the same headers. A worker
    connection to the scheduler must carry the ``Authorization`` header.
    """
    extra_conn_args = {"headers": {"Authorization": "Token abcd"}}
    if protocol == "ws://":
        security = Security(extra_conn_args=extra_conn_args)
    else:
        xfail_ssl_issue5601()
        pytest.importorskip("cryptography")
        security = Security.temporary(extra_conn_args=extra_conn_args)
    async with Scheduler(
        protocol=protocol, security=security, dashboard_address=":0"
    ) as s:
        connection_args = security.get_connection_args("worker")
        comm = await connect(s.address, **connection_args)
        try:
            assert comm.sock.request.headers.get("Authorization") == "Token abcd"
        finally:
            # Close even when the header check fails, so the comm is not leaked.
            await comm.close()
コード例 #10
0
    def __init__(
            self,
            n_workers=0,
            job_cls: Job = None,
            # Cluster keywords
            loop=None,
            security=None,
            shared_temp_directory=None,
            silence_logs="error",
            name=None,
            asynchronous=False,
            # Scheduler-only keywords
            dashboard_address=None,
            host=None,
            scheduler_options=None,
            scheduler_cls=Scheduler,  # Use local scheduler for now
            # Options for both scheduler and workers
            interface=None,
            protocol=None,
            # Job keywords
            config_name=None,
            **job_kwargs):
        """Build scheduler and worker specs from arguments and dask config.

        The specs are then passed to the base-class constructor.

        Parameters
        ----------
        n_workers : int, default 0
            If non-zero, passed to ``self.scale`` after construction.
        job_cls : Job, optional
            Job class to launch; overrides the class attribute ``job_cls``.
        loop, silence_logs, name, asynchronous
            Forwarded to the base-class constructor.
        security : Security, True or None
            ``True`` generates temporary credentials with
            ``Security.temporary()``. Any non-None value makes the default
            protocol ``"tls://"``.
        shared_temp_directory : str, optional
            Defaults to config ``jobqueue.<config_name>.shared-temp-directory``.
        dashboard_address, host
            Not accepted here; pass them through ``scheduler_options``.
        scheduler_options : dict, optional
            Overrides for the scheduler options. Defaults to config
            ``jobqueue.<config_name>.scheduler-options``.
        scheduler_cls : type
            Scheduler class placed in the scheduler spec.
        interface : str, optional
            Network interface for workers, and for the scheduler unless
            ``scheduler_options`` sets ``host`` or ``interface``. Defaults to
            config ``jobqueue.<config_name>.interface``.
        protocol : str, optional
            Comm protocol for scheduler and workers.
        config_name : str, optional
            Config section name; defaults to ``job_cls.default_config_name()``.
        **job_kwargs
            Extra options forwarded to the job class.

        Raises
        ------
        ValueError
            If no Job class is available, or if ``dashboard_address`` or
            ``host`` is passed directly.
        ImportError
            If ``security=True`` and ``Security.temporary()`` raises
            ImportError (e.g. ``cryptography`` is missing).
        """
        self.status = Status.created

        # An explicit ``job_cls`` argument wins over the class-level default.
        self.job_cls = (job_cls if job_cls is not None
                        else getattr(type(self), "job_cls", None))

        if self.job_cls is None:
            raise ValueError(
                "You need to specify a Job type. Two cases:\n"
                "- you are inheriting from JobQueueCluster (most likely): you need to add a 'job_cls' class variable "
                "in your JobQueueCluster-derived class {}\n"
                "- you are using JobQueueCluster directly (less likely, only useful for tests): "
                "please explicitly pass a Job type through the 'job_cls' parameter."
                .format(type(self)))

        if dashboard_address is not None:
            raise ValueError(
                "Please pass 'dashboard_address' through 'scheduler_options': use\n"
                'cluster = {0}(..., scheduler_options={{"dashboard_address": ":12345"}}) rather than\n'
                'cluster = {0}(..., dashboard_address=":12345")'.format(
                    self.__class__.__name__))

        if host is not None:
            raise ValueError(
                "Please pass 'host' through 'scheduler_options': use\n"
                'cluster = {0}(..., scheduler_options={{"host": "your-host"}}) rather than\n'
                'cluster = {0}(..., host="your-host")'.format(
                    self.__class__.__name__))

        if config_name is None:
            config_name = self.job_cls.default_config_name()

        if interface is None:
            interface = dask.config.get("jobqueue.%s.interface" % config_name)
        if scheduler_options is None:
            scheduler_options = dask.config.get(
                "jobqueue.%s.scheduler-options" % config_name, {})

        # Any security object (or True) implies TLS unless a protocol is given.
        if protocol is None and security is not None:
            protocol = "tls://"

        if security is True:
            try:
                security = Security.temporary()
            except ImportError as err:
                raise ImportError(
                    "In order to use TLS without pregenerated certificates `cryptography` is required, "
                    "please install it using either pip or conda") from err

        default_scheduler_options = {
            "protocol": protocol,
            "dashboard_address": ":8787",
            "security": security,
        }
        # scheduler_options overrides parameters common to both workers and scheduler
        scheduler_options = dict(default_scheduler_options,
                                 **scheduler_options)

        # Use the same network interface as the workers if scheduler ip has not
        # been set through scheduler_options via 'host' or 'interface'
        if "host" not in scheduler_options and "interface" not in scheduler_options:
            scheduler_options["interface"] = interface

        scheduler = {
            "cls": scheduler_cls,
            "options": scheduler_options,
        }

        if shared_temp_directory is None:
            shared_temp_directory = dask.config.get(
                "jobqueue.%s.shared-temp-directory" % config_name)
        self.shared_temp_directory = shared_temp_directory

        job_kwargs["config_name"] = config_name
        job_kwargs["interface"] = interface
        job_kwargs["protocol"] = protocol
        job_kwargs["security"] = self._get_worker_security(security)

        self._job_kwargs = job_kwargs

        worker = {"cls": self.job_cls, "options": self._job_kwargs}
        # With several processes per job, each job hosts a group of workers
        # suffixed "-0", "-1", ...
        if "processes" in self._job_kwargs and self._job_kwargs[
                "processes"] > 1:
            worker["group"] = [
                "-" + str(i) for i in range(self._job_kwargs["processes"])
            ]

        self._dummy_job  # trigger property to ensure that the job is valid

        super().__init__(
            scheduler=scheduler,
            worker=worker,
            loop=loop,
            security=security,
            silence_logs=silence_logs,
            asynchronous=asynchronous,
            name=name,
        )

        if n_workers:
            self.scale(n_workers)
コード例 #11
0
def test_repr_temp_keys():
    """The repr of temporary credentials labels them as in-memory."""
    import pytest

    # Security.temporary() requires `cryptography`; skip rather than error.
    pytest.importorskip("cryptography")
    sec = Security.temporary()
    representation = repr(sec)
    assert "Temporary (In-memory)" in representation
コード例 #12
0
    def __init__(
        self,
        n_workers: int = 0,
        worker_class: str = "dask.distributed.Nanny",
        worker_options: dict = None,
        scheduler_options: dict = None,
        docker_image="daskdev/dask:latest",
        docker_args: str = "",
        env_vars: dict = None,
        security: "bool | Security" = True,
        protocol: str = None,
        **kwargs,
    ):
        """Configure security, protocol and Docker options, then init the base cluster.

        Subclasses must provide ``scheduler_class``, ``worker_class`` and the
        ``options``, ``scheduler_options`` and ``worker_options`` dicts before
        this runs, because they are read and updated here.

        Parameters
        ----------
        n_workers : int
            Stored as ``self._n_workers``.
        worker_class : str
            Stored in ``self.worker_options["worker_class"]``.
        worker_options, scheduler_options : dict, optional
            Nested under ``"worker_options"`` / ``"scheduler_options"`` in the
            respective option dicts. ``None`` means an empty dict.
        docker_image : str
            Used unless ``self.scheduler_options`` already has a truthy
            ``"docker_image"``.
        docker_args : str
            Copied into all three option dicts.
        env_vars : dict, optional
            Copied into scheduler and worker options. ``None`` means an
            empty dict.
        security : bool or Security
            Falsey disables security. ``True`` uses ``Security.temporary()``.
            A ``Security`` instance is used as-is.
        protocol : str, optional
            Defaults to ``"tls"`` when encryption is required, else ``"tcp"``.
        **kwargs
            Forwarded to the base-class constructor.

        Raises
        ------
        RuntimeError
            If the subclass did not set ``scheduler_class``/``worker_class``.
        TypeError
            If ``security`` is truthy but neither ``True`` nor a ``Security``.

        Notes
        -----
        When encryption is required, this updates the global dask config.
        ``dask.config.set`` is not used as a context manager here, so the
        change persists.
        """
        if self.scheduler_class is None or self.worker_class is None:
            raise RuntimeError(
                "VMCluster is not intended to be used directly. See docstring for more info."
            )
        # Fresh dicts per call: shared ``{}`` defaults would otherwise be
        # aliased into the option dicts of every cluster instance.
        if worker_options is None:
            worker_options = {}
        if scheduler_options is None:
            scheduler_options = {}
        if env_vars is None:
            env_vars = {}

        self._n_workers = n_workers

        if not security:
            self.security = None
        elif security is True:
            # True indicates self-signed temporary credentials should be used
            self.security = Security.temporary()
        elif not isinstance(security, Security):
            raise TypeError("security must be a Security object")
        else:
            self.security = security

        encrypted = bool(self.security and self.security.require_encryption)

        if protocol is None:
            self.protocol = "tls" if encrypted else "tcp"
        else:
            self.protocol = protocol

        if encrypted:
            dask.config.set({
                "distributed.comm.default-scheme":
                self.protocol,
                "distributed.comm.require-encryption":
                True,
                "distributed.comm.tls.ca-file":
                self.security.tls_ca_file,
                "distributed.comm.tls.scheduler.key":
                self.security.tls_scheduler_key,
                "distributed.comm.tls.scheduler.cert":
                self.security.tls_scheduler_cert,
                "distributed.comm.tls.worker.key":
                self.security.tls_worker_key,
                "distributed.comm.tls.worker.cert":
                self.security.tls_worker_cert,
                "distributed.comm.tls.client.key":
                self.security.tls_client_key,
                "distributed.comm.tls.client.cert":
                self.security.tls_client_cert,
            })

        image = self.scheduler_options.get("docker_image",
                                           False) or docker_image
        self.options["docker_image"] = image
        self.scheduler_options["docker_image"] = image
        self.scheduler_options["env_vars"] = env_vars
        # NOTE(review): the raw ``protocol`` argument (possibly None) is
        # forwarded here and below, not the resolved ``self.protocol``.
        # Confirm downstream code relies on that before changing it.
        self.scheduler_options["protocol"] = protocol
        self.scheduler_options["scheduler_options"] = scheduler_options
        self.worker_options["env_vars"] = env_vars
        self.options["docker_args"] = docker_args
        self.scheduler_options["docker_args"] = docker_args
        self.worker_options["docker_args"] = docker_args
        self.worker_options["docker_image"] = image
        self.worker_options["worker_class"] = worker_class
        self.worker_options["protocol"] = protocol
        self.worker_options["worker_options"] = worker_options
        self.uuid = str(uuid.uuid4())[:8]

        super().__init__(**kwargs, security=self.security)
コード例 #13
0
ファイル: test_ws.py プロジェクト: andersy005/distributed
import pytest

import dask

from distributed import Client, Scheduler, Worker
from distributed.comm import connect, listen, ws
from distributed.comm.core import FatalCommClosedError
from distributed.comm.registry import backends, get_backend
from distributed.security import Security
from distributed.utils_test import (  # noqa: F401
    cleanup, gen_cluster, get_client_ssl_context, get_server_ssl_context, inc,
)

from .test_comms import check_tls_extra

# Temporary (in-memory) TLS credentials, created once when this module is
# imported. NOTE(review): Security.temporary() needs `cryptography` (other
# tests guard it with pytest.importorskip); here it runs unguarded at import.
security = Security.temporary()


def test_registered():
    """The ``ws`` comm backend is registered and resolves to ``ws.WSBackend``."""
    assert "ws" in backends
    assert isinstance(get_backend("ws"), ws.WSBackend)


@pytest.mark.asyncio
async def test_listen_connect(cleanup):
    async def handle_comm(comm):
        while True:
            msg = await comm.read()
            await comm.write(msg)
コード例 #14
0
def test_repr_temp_keys():
    """The repr of temporary credentials labels them as in-memory."""
    xfail_ssl_issue5601()
    pytest.importorskip("cryptography")
    assert "Temporary (In-memory)" in repr(Security.temporary())
コード例 #15
0
    def __init__(
        self,
        name=None,
        n_workers=None,
        threads_per_worker=None,
        processes=None,
        loop=None,
        start=None,
        host=None,
        ip=None,
        scheduler_port=0,
        silence_logs=logging.WARN,
        dashboard_address=":8787",
        worker_dashboard_address=None,
        diagnostics_port=None,
        services=None,
        worker_services=None,
        service_kwargs=None,
        asynchronous=False,
        security=None,
        protocol=None,
        blocked_handlers=None,
        interface=None,
        worker_class=None,
        scheduler_kwargs=None,
        scheduler_sync_interval=1,
        **worker_kwargs,
    ):
        """Build scheduler and worker specs for a local cluster.

        The specs are then passed to the base-class constructor.

        Parameters
        ----------
        name, loop, asynchronous, silence_logs, scheduler_sync_interval
            Forwarded to the base-class constructor. ``silence_logs`` is also
            passed to each worker.
        n_workers, threads_per_worker : int, optional
            When both are None they are derived from ``nprocesses_nthreads()``
            (processes) or set to ``1`` / ``CPU_COUNT`` (threads). When only
            one is given, the other is derived from ``CPU_COUNT``.
            ``threads_per_worker=0`` is deprecated and treated as None.
        processes : bool, optional
            Defaults to True unless ``worker_class`` is given and is not a
            ``Nanny`` subclass.
        start
            Accepted but not used by this constructor.
        host, ip : str, optional
            ``ip`` is an alias that overrides ``host``. A ``"proto://"``
            prefix on ``host`` sets the default protocol.
        scheduler_port : int
            Scheduler port. 0 together with ``processes=False`` selects
            ``inproc://`` by default.
        dashboard_address : str or None
            Scheduler dashboard address; None disables the dashboard.
        worker_dashboard_address : str or None
            Worker dashboard address; None disables worker dashboards.
        diagnostics_port
            Deprecated alias for ``dashboard_address``.
        services, worker_services, service_kwargs, blocked_handlers
            Passed to the scheduler (``worker_services`` to the workers).
        security : Security, True, False or None
            None/False load the default configuration. True generates
            temporary credentials. A ``Security`` instance is used as-is.
        protocol : str, optional
            Comm protocol; ``"://"`` is appended if missing.
        interface : str, optional
            Network interface for scheduler and workers.
        worker_class : type, optional
            Defaults to ``Nanny`` with processes, else ``Worker``.
        scheduler_kwargs : dict, optional
            Merged over the computed scheduler options.
        **worker_kwargs
            Extra worker options.

        Raises
        ------
        TypeError
            If ``security`` is not None, bool or a ``Security`` instance.
        """
        if ip is not None:
            # In the future we should warn users about this move
            # warnings.warn("The ip keyword has been moved to host")
            host = ip

        # stacklevel=2 attributes the warnings below to the caller, not to
        # this constructor.
        if diagnostics_port is not None:
            warnings.warn("diagnostics_port has been deprecated. "
                          "Please use `dashboard_address=` instead",
                          stacklevel=2)
            dashboard_address = diagnostics_port

        if threads_per_worker == 0:
            warnings.warn(
                "Setting `threads_per_worker` to 0 has been deprecated. "
                "Please set to None or to a specific int.",
                stacklevel=2)
            threads_per_worker = None

        if "dashboard" in worker_kwargs:
            warnings.warn(
                "Setting `dashboard` is discouraged. "
                "Please set `dashboard_address` to affect the scheduler (more common) "
                "and `worker_dashboard_address` for the worker (less common).",
                stacklevel=2)

        if processes is None:
            processes = worker_class is None or issubclass(worker_class, Nanny)
        if worker_class is None:
            worker_class = Nanny if processes else Worker

        self.status = None
        self.processes = processes

        if security is None or security is False:
            # Falsey values load the default configuration
            security = Security()
        elif security is True:
            # True indicates self-signed temporary credentials should be used
            security = Security.temporary()
        elif not isinstance(security, Security):
            raise TypeError("security must be a Security object")

        if protocol is None:
            if host and "://" in host:
                protocol = host.split("://")[0]
            elif security and security.require_encryption:
                protocol = "tls://"
            elif not self.processes and not scheduler_port:
                protocol = "inproc://"
            else:
                protocol = "tcp://"
        if not protocol.endswith("://"):
            protocol = protocol + "://"

        if host is None and not protocol.startswith(
                "inproc") and not interface:
            host = "127.0.0.1"

        services = services or {}
        worker_services = worker_services or {}
        if n_workers is None and threads_per_worker is None:
            if processes:
                n_workers, threads_per_worker = nprocesses_nthreads()
            else:
                n_workers = 1
                threads_per_worker = CPU_COUNT
        if n_workers is None and threads_per_worker is not None:
            n_workers = max(1, CPU_COUNT //
                            threads_per_worker) if processes else 1
        if n_workers and threads_per_worker is None:
            # Overcommit threads per worker, rather than undercommit
            threads_per_worker = max(1, int(math.ceil(CPU_COUNT / n_workers)))
        if n_workers and "memory_limit" not in worker_kwargs:
            worker_kwargs["memory_limit"] = parse_memory_limit(
                "auto", 1, n_workers)

        worker_kwargs.update({
            "host": host,
            "nthreads": threads_per_worker,
            "services": worker_services,
            "dashboard_address": worker_dashboard_address,
            "dashboard": worker_dashboard_address is not None,
            "interface": interface,
            "protocol": protocol,
            "security": security,
            "silence_logs": silence_logs,
        })

        scheduler = {
            "cls":
            Scheduler,
            "options":
            toolz.merge(
                dict(
                    host=host,
                    services=services,
                    service_kwargs=service_kwargs,
                    security=security,
                    port=scheduler_port,
                    interface=interface,
                    protocol=protocol,
                    dashboard=dashboard_address is not None,
                    dashboard_address=dashboard_address,
                    blocked_handlers=blocked_handlers,
                ),
                scheduler_kwargs or {},
            ),
        }

        worker = {"cls": worker_class, "options": worker_kwargs}
        workers = {i: worker for i in range(n_workers)}

        super().__init__(
            name=name,
            scheduler=scheduler,
            workers=workers,
            worker=worker,
            loop=loop,
            asynchronous=asynchronous,
            silence_logs=silence_logs,
            security=security,
            scheduler_sync_interval=scheduler_sync_interval,
        )
コード例 #16
0
from distributed import Client, Scheduler, Worker
from distributed.comm import connect, listen, ws
from distributed.comm.registry import backends, get_backend
from distributed.security import Security
from distributed.utils_test import (  # noqa: F401
    cleanup,
    gen_cluster,
    get_client_ssl_context,
    get_server_ssl_context,
    inc,
)

from .test_comms import check_tls_extra

# Temporary (in-memory) TLS credentials, created once when this module is
# imported. NOTE(review): Security.temporary() needs `cryptography` (other
# tests guard it with pytest.importorskip); here it runs unguarded at import.
security = Security.temporary()


def test_registered():
    """Looking up ``ws`` in the comm registry yields a WebSocket backend."""
    is_known = "ws" in backends
    assert is_known
    resolved = get_backend("ws")
    assert isinstance(resolved, ws.WSBackend)


@pytest.mark.asyncio
async def test_listen_connect(cleanup):
    async def handle_comm(comm):
        while True:
            msg = await comm.read()
            await comm.write(msg)