Esempio n. 1
0
async def test_tls_reject_certificate():
    """Check that TLS peers not signed by the test CA are refused.

    Both directions are exercised: a listener refusing a connector that
    presents a self-signed certificate, and a connector refusing a listener
    that presents one. In between, a properly signed client connects to the
    first listener as a sanity check. Each listener is stopped when its
    phase ends, even if an assertion fails.
    """
    cli_ctx = get_client_ssl_context()
    serv_ctx = get_server_ssl_context()

    # These certs are not signed by our test CA
    bad_cert_key = ("tls-self-signed-cert.pem", "tls-self-signed-key.pem")
    bad_cli_ctx = get_client_ssl_context(*bad_cert_key)
    bad_serv_ctx = get_server_ssl_context(*bad_cert_key)

    async def handle_comm(comm):
        scheme, _ = parse_address(comm.peer_address)
        assert scheme == "tls"
        await comm.close()

    # Listener refuses a connector not signed by the CA
    listener = listen("tls://", handle_comm, connection_args={"ssl_context": serv_ctx})
    await listener.start()
    try:
        with pytest.raises(EnvironmentError) as excinfo:
            comm = await connect(
                listener.contact_address,
                timeout=0.5,
                connection_args={"ssl_context": bad_cli_ctx},
            )
            await comm.write({"x": "foo"})  # TODO: why is this necessary in Tornado 6 ?

        # The error text is not checked on Windows, where it differs.
        if os.name != "nt":
            # See https://serverfault.com/questions/793260/what-does-tlsv1-alert-unknown-ca-mean
            assert "unknown ca" in str(excinfo.value)

        # Sanity check
        comm = await connect(
            listener.contact_address, timeout=2, connection_args={"ssl_context": cli_ctx}
        )
        await comm.close()
    finally:
        listener.stop()

    # Connector refuses a listener not signed by the CA
    listener = listen(
        "tls://", handle_comm, connection_args={"ssl_context": bad_serv_ctx}
    )
    await listener.start()
    try:
        with pytest.raises(EnvironmentError) as excinfo:
            await connect(
                listener.contact_address,
                timeout=2,
                connection_args={"ssl_context": cli_ctx},
            )
        assert "certificate verify failed" in str(excinfo.value)
    finally:
        listener.stop()
Esempio n. 2
0
async def test_require_encryption():
    """
    Functional test for "require_encryption" setting.

    Transports that satisfy the requirement (inproc, tls) must connect whether
    or not encryption is required. Plain tcp works by default, but once
    ``distributed.comm.require-encryption`` is set it must be refused on both
    the connecting and the listening side. The refusal must be a RuntimeError
    mentioning "encryption required".
    """

    async def handle_comm(comm):
        comm.abort()

    c = {
        "distributed.comm.tls.ca-file": ca_file,
        "distributed.comm.tls.scheduler.key": key1,
        "distributed.comm.tls.scheduler.cert": cert1,
        "distributed.comm.tls.worker.cert": keycert1,
    }
    with dask.config.set(c):
        sec = Security()

    c["distributed.comm.require-encryption"] = True
    with dask.config.set(c):
        sec2 = Security()

    for listen_addr in ["inproc://", "tls://"]:
        async with listen(
            listen_addr, handle_comm, **sec.get_listen_args("scheduler")
        ) as listener:
            comm = await connect(
                listener.contact_address, **sec2.get_connection_args("worker")
            )
            comm.abort()

        async with listen(
            listen_addr, handle_comm, **sec2.get_listen_args("scheduler")
        ) as listener:
            comm = await connect(
                listener.contact_address, **sec2.get_connection_args("worker")
            )
            comm.abort()

    @contextmanager
    def check_encryption_error():
        """Expect a RuntimeError whose message mentions "encryption required"."""
        with pytest.raises(RuntimeError) as excinfo:
            yield
        assert "encryption required" in str(excinfo.value)

    for listen_addr in ["tcp://"]:
        async with listen(
            listen_addr, handle_comm, **sec.get_listen_args("scheduler")
        ) as listener:
            comm = await connect(
                listener.contact_address, **sec.get_connection_args("worker")
            )
            comm.abort()

            # Requiring encryption refuses the plain tcp connection...
            with check_encryption_error():
                await connect(
                    listener.contact_address, **sec2.get_connection_args("worker")
                )

        # ...and refuses to create a plain tcp listener.
        with check_encryption_error():
            listen(listen_addr, handle_comm, **sec2.get_listen_args("scheduler"))
Esempio n. 3
0
async def run_bench(protocol, nbytes, niters):
    """Time ``niters`` ping/pong round-trips carrying an ``nbytes`` payload.

    ``protocol == 'tcp'`` benchmarks loopback TCP; any other value benchmarks
    UCX on ``ucp.get_address()``. A fresh connection is opened for every
    iteration, and a dot is printed after each one. The total duration and
    ``len(data) * niters / duration`` (in MB/s) are printed at the end.
    The listener is stopped even if a round-trip fails.
    """
    data = np.random.randint(0, 255, size=nbytes, dtype=np.uint8)
    item = Serialized(*serialize(data))

    address = 'tcp://127.0.0.1' if protocol == 'tcp' else 'ucx://' + ucp.get_address()
    listener = listen(address, server_handle_comm, deserialize=False)
    listener.start()
    try:
        start = clock()

        for _ in range(niters):
            comm = await connect(listener.contact_address, deserialize=False)
            await comm.write({'op': 'ping', 'item': item})
            msg = await comm.read()
            assert msg['op'] == 'pong'
            assert isinstance(msg['item'], Serialized)
            await comm.close()
            print('.', end='', flush=True)
        print()

        end = clock()
    finally:
        listener.stop()

    dt = end - start
    rate = len(data) * niters / dt
    print("duration: %s => rate: %d MB/s" % (dt, rate / 1e6))
Esempio n. 4
0
def test_require_encryption():
    """
    Functional test for "require_encryption" setting.

    Transports that satisfy the requirement (inproc, tls) must connect whether
    or not encryption is required. Plain tcp works by default, but once
    'require-encryption' is set it must be refused on both the connecting and
    the listening side. The refusal must be a RuntimeError mentioning
    "encryption required".
    """
    @gen.coroutine
    def handle_comm(comm):
        comm.abort()

    c = {
        'tls': {
            'ca-file': ca_file,
            'scheduler': {
                'key': key1,
                'cert': cert1,
            },
            'worker': {
                'cert': keycert1,
            },
        },
    }
    with new_config(c):
        sec = Security()
    c['require-encryption'] = True
    with new_config(c):
        sec2 = Security()

    for listen_addr in ['inproc://', 'tls://']:
        with listen(listen_addr, handle_comm,
                    connection_args=sec.get_listen_args('scheduler')) as listener:
            comm = yield connect(listener.contact_address,
                                 connection_args=sec2.get_connection_args('worker'))
            comm.abort()

        with listen(listen_addr, handle_comm,
                    connection_args=sec2.get_listen_args('scheduler')) as listener:
            comm = yield connect(listener.contact_address,
                                 connection_args=sec2.get_connection_args('worker'))
            comm.abort()

    @contextmanager
    def check_encryption_error():
        """Expect a RuntimeError whose message mentions "encryption required"."""
        with pytest.raises(RuntimeError) as excinfo:
            yield
        assert "encryption required" in str(excinfo.value)

    for listen_addr in ['tcp://']:
        with listen(listen_addr, handle_comm,
                    connection_args=sec.get_listen_args('scheduler')) as listener:
            comm = yield connect(listener.contact_address,
                                 connection_args=sec.get_connection_args('worker'))
            comm.abort()

            # Requiring encryption refuses the plain tcp connection...
            with check_encryption_error():
                yield connect(listener.contact_address,
                              connection_args=sec2.get_connection_args('worker'))

        # ...and refuses to create a plain tcp listener.
        with check_encryption_error():
            listen(listen_addr, handle_comm,
                   connection_args=sec2.get_listen_args('scheduler'))
def test_require_encryption():
    """
    Functional test for "require_encryption" setting.

    Transports that satisfy the requirement (inproc, tls) must connect whether
    or not encryption is required. Plain tcp works by default, but once
    'require-encryption' is set it must be refused on both the connecting and
    the listening side. The refusal must be a RuntimeError mentioning
    "encryption required".
    """
    @gen.coroutine
    def handle_comm(comm):
        comm.abort()

    c = {
        'tls': {
            'ca-file': ca_file,
            'scheduler': {
                'key': key1,
                'cert': cert1,
            },
            'worker': {
                'cert': keycert1,
            },
        },
    }
    with new_config(c):
        sec = Security()
    c['require-encryption'] = True
    with new_config(c):
        sec2 = Security()

    for listen_addr in ['inproc://', 'tls://']:
        with listen(listen_addr, handle_comm,
                    connection_args=sec.get_listen_args('scheduler')) as listener:
            comm = yield connect(listener.contact_address,
                                 connection_args=sec2.get_connection_args('worker'))
            comm.abort()

        with listen(listen_addr, handle_comm,
                    connection_args=sec2.get_listen_args('scheduler')) as listener:
            comm = yield connect(listener.contact_address,
                                 connection_args=sec2.get_connection_args('worker'))
            comm.abort()

    @contextmanager
    def check_encryption_error():
        """Expect a RuntimeError whose message mentions "encryption required"."""
        with pytest.raises(RuntimeError) as excinfo:
            yield
        assert "encryption required" in str(excinfo.value)

    for listen_addr in ['tcp://']:
        with listen(listen_addr, handle_comm,
                    connection_args=sec.get_listen_args('scheduler')) as listener:
            comm = yield connect(listener.contact_address,
                                 connection_args=sec.get_connection_args('worker'))
            comm.abort()

            # Requiring encryption refuses the plain tcp connection...
            with check_encryption_error():
                yield connect(listener.contact_address,
                              connection_args=sec2.get_connection_args('worker'))

        # ...and refuses to create a plain tcp listener.
        with check_encryption_error():
            listen(listen_addr, handle_comm,
                   connection_args=sec2.get_listen_args('scheduler'))
Esempio n. 6
0
def test_tls_reject_certificate():
    """Check that TLS peers not signed by the test CA are refused.

    Both directions are exercised: a listener refusing a connector that
    presents a self-signed certificate, and a connector refusing a listener
    that presents one. In between, a properly signed client connects to the
    first listener as a sanity check. Each listener is stopped when its
    phase ends, even if an assertion fails.
    """
    cli_ctx = get_client_ssl_context()
    serv_ctx = get_server_ssl_context()

    # These certs are not signed by our test CA
    bad_cert_key = ('tls-self-signed-cert.pem', 'tls-self-signed-key.pem')
    bad_cli_ctx = get_client_ssl_context(*bad_cert_key)
    bad_serv_ctx = get_server_ssl_context(*bad_cert_key)

    @gen.coroutine
    def handle_comm(comm):
        scheme, _ = parse_address(comm.peer_address)
        assert scheme == 'tls'
        yield comm.close()

    # Listener refuses a connector not signed by the CA
    listener = listen('tls://',
                      handle_comm,
                      connection_args={'ssl_context': serv_ctx})
    listener.start()
    try:
        with pytest.raises(EnvironmentError) as excinfo:
            yield connect(listener.contact_address,
                          timeout=0.5,
                          connection_args={'ssl_context': bad_cli_ctx})

        # The wrong error is reported on Python 2, see https://github.com/tornadoweb/tornado/pull/2028
        # The error text is not checked on Windows either, where it differs.
        if sys.version_info >= (3, ) and os.name != 'nt':
            # See https://serverfault.com/questions/793260/what-does-tlsv1-alert-unknown-ca-mean
            assert "unknown ca" in str(excinfo.value)

        # Sanity check
        comm = yield connect(listener.contact_address,
                             timeout=0.5,
                             connection_args={'ssl_context': cli_ctx})
        yield comm.close()
    finally:
        listener.stop()

    # Connector refuses a listener not signed by the CA
    listener = listen('tls://',
                      handle_comm,
                      connection_args={'ssl_context': bad_serv_ctx})
    listener.start()
    try:
        with pytest.raises(EnvironmentError) as excinfo:
            yield connect(listener.contact_address,
                          timeout=0.5,
                          connection_args={'ssl_context': cli_ctx})
        # The wrong error is reported on Python 2, see https://github.com/tornadoweb/tornado/pull/2028
        if sys.version_info >= (3, ):
            assert "certificate verify failed" in str(excinfo.value)
    finally:
        listener.stop()
Esempio n. 7
0
def test_tls_reject_certificate():
    """Check that TLS peers not signed by the test CA are refused.

    Both directions are exercised: a listener refusing a connector that
    presents a self-signed certificate, and a connector refusing a listener
    that presents one. In between, a properly signed client connects to the
    first listener as a sanity check. Each listener is stopped when its
    phase ends, even if an assertion fails.
    """
    cli_ctx = get_client_ssl_context()
    serv_ctx = get_server_ssl_context()

    # These certs are not signed by our test CA
    bad_cert_key = ('tls-self-signed-cert.pem', 'tls-self-signed-key.pem')
    bad_cli_ctx = get_client_ssl_context(*bad_cert_key)
    bad_serv_ctx = get_server_ssl_context(*bad_cert_key)

    @gen.coroutine
    def handle_comm(comm):
        scheme, _ = parse_address(comm.peer_address)
        assert scheme == 'tls'
        yield comm.close()

    # Listener refuses a connector not signed by the CA
    listener = listen('tls://', handle_comm,
                      connection_args={'ssl_context': serv_ctx})
    listener.start()
    try:
        with pytest.raises(EnvironmentError) as excinfo:
            yield connect(listener.contact_address, timeout=0.5,
                          connection_args={'ssl_context': bad_cli_ctx})

        # The wrong error is reported on Python 2, see https://github.com/tornadoweb/tornado/pull/2028
        # The error text is not checked on Windows either, where it differs.
        if sys.version_info >= (3,) and os.name != 'nt':
            # See https://serverfault.com/questions/793260/what-does-tlsv1-alert-unknown-ca-mean
            assert "unknown ca" in str(excinfo.value)

        # Sanity check
        comm = yield connect(listener.contact_address, timeout=0.5,
                             connection_args={'ssl_context': cli_ctx})
        yield comm.close()
    finally:
        listener.stop()

    # Connector refuses a listener not signed by the CA
    listener = listen('tls://', handle_comm,
                      connection_args={'ssl_context': bad_serv_ctx})
    listener.start()
    try:
        with pytest.raises(EnvironmentError) as excinfo:
            yield connect(listener.contact_address, timeout=0.5,
                          connection_args={'ssl_context': cli_ctx})
        # The wrong error is reported on Python 2, see https://github.com/tornadoweb/tornado/pull/2028
        if sys.version_info >= (3,):
            assert "certificate verify failed" in str(excinfo.value)
    finally:
        listener.stop()
Esempio n. 8
0
 def _main(self, address):
     """Run ``self.N_CONNECTS`` ``_connect_close`` calls against a fresh listener.

     All connect/close operations are created up front, and the list is
     yielded so the coroutine runner waits for every one of them. The
     listener is stopped afterwards, even if one of them fails.
     """
     listener = listen(address, self._handle_comm)
     listener.start()
     try:
         yield [
             self._connect_close(listener.contact_address)
             for _ in range(self.N_CONNECTS)
         ]
     finally:
         listener.stop()
Esempio n. 9
0
async def test_expect_ssl_context(cleanup):
    """A ``wss://`` connect without an ``ssl_context`` must fail fatally.

    The listener is given a server SSL context, but ``connect`` is called with
    no connection arguments. It must raise ``FatalCommClosedError`` with a
    message starting "TLS expects a `ssl_context`".
    """
    server_ctx = get_server_ssl_context()

    async with listen("wss://", lambda comm: comm,
                      ssl_context=server_ctx) as listener:
        with pytest.raises(FatalCommClosedError,
                           match="TLS expects a `ssl_context` *"):
            # The connect is expected to raise, so no comm is ever bound.
            await connect(listener.contact_address)
Esempio n. 10
0
def test_tls_listen_connect():
    """
    Functional test for TLS connection args.

    A worker-role comm receives a greeting over TLS. The "client" role, for
    which no TLS settings are configured here, must be rejected with
    TypeError. A forced cipher must be honoured; TLS 1.3 suites are also
    accepted.
    """
    @gen.coroutine
    def handle_comm(comm):
        assert comm.peer_address.startswith("tls://")
        yield comm.write("hello")
        yield comm.close()

    tls_config = {
        "tls": {
            "ca-file": ca_file,
            "scheduler": {"key": key1, "cert": cert1},
            "worker": {"cert": keycert1},
        }
    }
    with new_config(tls_config):
        sec = Security()

    tls_config["tls"]["ciphers"] = FORCED_CIPHER
    with new_config(tls_config):
        forced_cipher_sec = Security()

    listen_args = sec.get_listen_args("scheduler")
    with listen("tls://", handle_comm, connection_args=listen_args) as listener:
        address = listener.contact_address

        worker_comm = yield connect(
            address, connection_args=sec.get_connection_args("worker"))
        greeting = yield worker_comm.read()
        assert greeting == "hello"
        worker_comm.abort()

        # No SSL context for client
        with pytest.raises(TypeError):
            yield connect(address,
                          connection_args=sec.get_connection_args("client"))

        # Check forced cipher
        forced_comm = yield connect(
            address,
            connection_args=forced_cipher_sec.get_connection_args("worker"))
        cipher, _, _ = forced_comm.extra_info["cipher"]
        assert cipher in [FORCED_CIPHER] + TLS_13_CIPHERS
        forced_comm.abort()
Esempio n. 11
0
 def _main(self, address, obj, n_transfers, **kwargs):
     """Send ``obj`` ``n_transfers`` times to a fresh listener and read the replies.

     ``kwargs`` are forwarded to both ``listen`` and ``connect``. The listener
     handler is ``self._handle_comm`` with ``n_transfers`` bound as its first
     argument. The listener is stopped afterwards, even if a transfer fails.
     """
     listener = listen(address, partial(self._handle_comm, n_transfers), **kwargs)
     yield listener.start()
     try:
         comm = yield connect(listener.contact_address, **kwargs)
         for _ in range(n_transfers):
             yield comm.write(obj)
         # Read back to ensure that the round-trip is complete
         for _ in range(n_transfers):
             yield comm.read()
         yield comm.close()
     finally:
         listener.stop()
Esempio n. 12
0
async def check_deserialize_eoferror(addr):
    """
    EOFError when deserializing should close the comm.

    The server writes a message whose payload (``_EOFRaising``) presumably
    raises EOFError when deserialized. The client's read must raise
    ``CommClosedError``, and the server handler expects the same on its next
    read.
    """
    async def handle_comm(comm):
        await comm.write({"data": to_serialize(_EOFRaising())})
        with pytest.raises(CommClosedError):
            await comm.read()

    async with listen(addr, handle_comm) as listener:
        # Deserialization must be enabled for the EOFError to be triggered.
        # Pass an explicit boolean: no local ``deserialize`` exists in this
        # function, so a bare name would resolve to a module-level object
        # (presumably the protocol function) rather than a flag.
        comm = await connect(listener.contact_address, deserialize=True)
        with pytest.raises(CommClosedError):
            await comm.read()
Esempio n. 13
0
async def test_inproc_handshakes_concurrently():
    """Open two raw inproc connections back to back, then close both.

    The backend connector is used directly, presumably bypassing the
    handshake performed by the top-level ``connect``.
    """
    async def handle_comm(comm):
        # Listeners call their handler with the new comm, so the handler must
        # accept it; a zero-argument handler would raise TypeError if invoked.
        pass

    async with listen("inproc://", handle_comm) as listener:
        addr = listener.listen_address
        scheme, loc = parse_address(addr)
        connector = get_backend(scheme).get_connector()

        comm1 = await connector.connect(loc)
        comm2 = await connector.connect(loc)
        await comm1.close()
        await comm2.close()
Esempio n. 14
0
def get_comm_pair(listen_addr, listen_args=None, connect_args=None):
    """Produce a connected ``(client_comm, server_comm)`` pair on *listen_addr*.

    Generator-based coroutine; the pair is delivered via ``gen.Return``.
    The server side is handed over through a queue. The listener is started
    here and never stopped.
    """
    accepted = queues.Queue()

    def on_connect(server_side):
        accepted.put(server_side)

    listener = listen(listen_addr, on_connect, connection_args=listen_args)
    listener.start()

    client_side = yield connect(listener.contact_address,
                                connection_args=connect_args)
    server_side = yield accepted.get()
    raise gen.Return((client_side, server_side))
Esempio n. 15
0
async def get_comm_pair(
    listen_addr="ucx://" + HOST, listen_args=None, connect_args=None, **kwargs
):
    """Return a connected ``(client_comm, server_comm)`` pair.

    ``listen_args`` and ``connect_args`` are extra keyword arguments for
    ``listen`` and ``connect`` respectively; ``kwargs`` are forwarded to both.
    The listener is only kept open (via ``async with``) until the pair has been
    established. The returned comms are not closed here.
    """
    # ``None`` defaults avoid sharing one mutable ``{}`` across calls.
    listen_args = {} if listen_args is None else listen_args
    connect_args = {} if connect_args is None else connect_args
    q = asyncio.queues.Queue()

    async def handle_comm(comm):
        await q.put(comm)

    listener = listen(listen_addr, handle_comm, **listen_args, **kwargs)
    async with listener:
        comm = await connect(listener.contact_address, **connect_args, **kwargs)
        serv_comm = await q.get()
        return (comm, serv_comm)
Esempio n. 16
0
async def test_listen_connect(cleanup):
    """Round-trip a bytes payload through an echo handler over ``ws://``."""

    async def handle_comm(comm):
        # Echo every incoming message back; the loop only ends on an error.
        while True:
            await comm.write(await comm.read())

    payload = b"Hello!"
    async with listen("ws://", handle_comm) as listener:
        comm = await connect(listener.contact_address)
        await comm.write(payload)
        assert await comm.read() == payload

        await comm.close()
Esempio n. 17
0
def check_many_listeners(addr):
    """Start 100 listeners on *addr*, then stop them all.

    Listeners that were already started are stopped even if starting a later
    one fails, so a failure does not leak open listeners.
    """
    @gen.coroutine
    def handle_comm(comm):
        pass

    listeners = []
    try:
        for _ in range(100):
            listener = listen(addr, handle_comm)
            listener.start()
            listeners.append(listener)
    finally:
        for listener in listeners:
            listener.stop()
Esempio n. 18
0
def test_tls_listen_connect():
    """
    Functional test for TLS connection args.

    A worker-role comm receives a greeting over TLS. The 'client' role, for
    which no TLS settings are configured here, must be rejected with
    TypeError. A forced cipher must be honoured; TLS 1.3 suites are also
    accepted.
    """
    @gen.coroutine
    def handle_comm(comm):
        assert comm.peer_address.startswith('tls://')
        yield comm.write('hello')
        yield comm.close()

    tls_config = {
        'tls': {
            'ca-file': ca_file,
            'scheduler': {'key': key1, 'cert': cert1},
            'worker': {'cert': keycert1},
        },
    }
    with new_config(tls_config):
        sec = Security()

    tls_config['tls']['ciphers'] = FORCED_CIPHER
    with new_config(tls_config):
        forced_cipher_sec = Security()

    listen_args = sec.get_listen_args('scheduler')
    with listen('tls://', handle_comm, connection_args=listen_args) as listener:
        address = listener.contact_address

        worker_comm = yield connect(
            address, connection_args=sec.get_connection_args('worker'))
        greeting = yield worker_comm.read()
        assert greeting == 'hello'
        worker_comm.abort()

        # No SSL context for client
        with pytest.raises(TypeError):
            yield connect(address,
                          connection_args=sec.get_connection_args('client'))

        # Check forced cipher
        forced_comm = yield connect(
            address,
            connection_args=forced_cipher_sec.get_connection_args('worker'))
        cipher, _, _ = forced_comm.extra_info['cipher']
        assert cipher in [FORCED_CIPHER] + TLS_13_CIPHERS
        forced_comm.abort()
Esempio n. 19
0
async def test_inproc_continues_listening_after_handshake_error():
    """The inproc listener must still accept connections after one is dropped.

    Two raw backend connections are opened and closed in turn. Closing the
    first one presumably happens before any handshake completes; the second
    connect checks the listener is still usable.
    """
    async def handle_comm(comm):
        # Listeners call their handler with the new comm, so the handler must
        # accept it; a zero-argument handler would raise TypeError if invoked.
        pass

    async with listen("inproc://", handle_comm) as listener:
        addr = listener.listen_address
        scheme, loc = parse_address(addr)
        connector = get_backend(scheme).get_connector()

        comm = await connector.connect(loc)
        await comm.close()

        comm = await connector.connect(loc)
        await comm.close()
Esempio n. 20
0
def test_inproc_comm_closed_explicit_2():
    """Check CommClosedError propagation when inproc comms are closed explicitly.

    Three scenarios run against one ``inproc://`` listener:

    1. The client closes its comm while the server handler waits in
       ``read()``. That read must raise ``CommClosedError``, which the handler
       records in ``listener_errors``. Later client reads and writes must
       raise as well.
    2. The client writes a message; the handler reads it and closes its end.
       The client's next read and write must raise, and its comm must report
       closed.
    3. Same as 2, but the client polls (up to ~2s) until its comm reports
       closed.

    Finally, closing an already-closed comm twice must not raise.
    """
    # One entry is appended each time the server handler sees CommClosedError.
    listener_errors = []

    @gen.coroutine
    def handle_comm(comm):
        # Block until the peer sends a message or closes the comm.
        try:
            yield comm.read()
        except CommClosedError:
            assert comm.closed()
            listener_errors.append(True)
        else:
            # A message arrived: close from the server side (scenarios 2 and 3).
            comm.close()

    listener = listen('inproc://', handle_comm)
    listener.start()
    contact_addr = listener.contact_address

    # Scenario 1: the client closes first.
    comm = yield connect(contact_addr)
    comm.close()
    assert comm.closed()
    start = time()
    # Poll (for at most ~1s) until the server handler has recorded the error.
    while len(listener_errors) < 1:
        assert time() < start + 1
        yield gen.sleep(0.01)
    assert len(listener_errors) == 1

    with pytest.raises(CommClosedError):
        yield comm.read()
    with pytest.raises(CommClosedError):
        yield comm.write("foo")

    # Scenario 2: the server closes after receiving a message.
    comm = yield connect(contact_addr)
    # NOTE(review): this write is not yielded, unlike the writes inside
    # pytest.raises; presumably it only needs to be dispatched — confirm.
    comm.write("foo")
    with pytest.raises(CommClosedError):
        yield comm.read()
    with pytest.raises(CommClosedError):
        yield comm.write("foo")
    assert comm.closed()

    # Scenario 3: wait for the client side to notice the server's close.
    comm = yield connect(contact_addr)
    comm.write("foo")

    start = time()
    while not comm.closed():
        yield gen.sleep(0.01)
        assert time() < start + 2

    # Closing an already-closed comm (twice) must not raise.
    comm.close()
    comm.close()
Esempio n. 21
0
async def get_comm_pair(listen_addr, listen_args=None, connect_args=None, **kwargs):
    """Return a connected ``(client_comm, server_comm)`` pair on *listen_addr*.

    ``kwargs`` are forwarded to both ``listen`` and ``connect``. The server
    side is handed over through a queue. The listener is started here and
    never stopped.
    """
    incoming = queues.Queue()

    def enqueue(server_comm):
        incoming.put(server_comm)

    listener = listen(listen_addr, enqueue, connection_args=listen_args, **kwargs)
    await listener.start()

    client_comm = await connect(
        listener.contact_address, connection_args=connect_args, **kwargs
    )
    return client_comm, await incoming.get()
Esempio n. 22
0
def check_deserialize_eoferror(addr):
    """
    EOFError when deserializing should close the comm.

    The server writes a message whose payload (``_EOFRaising``) presumably
    raises EOFError when deserialized. The client's read must raise
    ``CommClosedError``, and the server handler expects the same on its next
    read.
    """
    @gen.coroutine
    def handle_comm(comm):
        yield comm.write({'data': to_serialize(_EOFRaising())})
        with pytest.raises(CommClosedError):
            yield comm.read()

    with listen(addr, handle_comm) as listener:
        # Deserialization must be enabled for the EOFError to be triggered.
        # Pass an explicit boolean: no local ``deserialize`` exists in this
        # function, so a bare name would resolve to a module-level object
        # (presumably the protocol function) rather than a flag.
        comm = yield connect(listener.contact_address, deserialize=True)
        with pytest.raises(CommClosedError):
            yield comm.read()
Esempio n. 23
0
def test_inproc_comm_closed_explicit_2():
    """Check CommClosedError propagation when inproc comms are closed explicitly.

    Three scenarios run against one ``inproc://`` listener:

    1. The client closes its comm while the server handler waits in
       ``read()``. That read must raise ``CommClosedError``, which the handler
       records in ``listener_errors``. Later client reads and writes must
       raise as well.
    2. The client writes a message; the handler reads it and closes its end.
       The client's next read and write must raise, and its comm must report
       closed.
    3. Same as 2, but the client polls (up to ~2s) until its comm reports
       closed.

    Finally, closing an already-closed comm twice must not raise.
    """
    # One entry is appended each time the server handler sees CommClosedError.
    listener_errors = []

    @gen.coroutine
    def handle_comm(comm):
        # Block until the peer sends a message or closes the comm.
        try:
            yield comm.read()
        except CommClosedError:
            assert comm.closed()
            listener_errors.append(True)
        else:
            # A message arrived: close from the server side (scenarios 2 and 3).
            yield comm.close()

    listener = listen("inproc://", handle_comm)
    listener.start()
    contact_addr = listener.contact_address

    # Scenario 1: the client closes first.
    comm = yield connect(contact_addr)
    yield comm.close()
    assert comm.closed()
    start = time()
    # Poll (for at most ~1s) until the server handler has recorded the error.
    while len(listener_errors) < 1:
        assert time() < start + 1
        yield gen.sleep(0.01)
    assert len(listener_errors) == 1

    with pytest.raises(CommClosedError):
        yield comm.read()
    with pytest.raises(CommClosedError):
        yield comm.write("foo")

    # Scenario 2: the server closes after receiving a message.
    comm = yield connect(contact_addr)
    yield comm.write("foo")
    with pytest.raises(CommClosedError):
        yield comm.read()
    with pytest.raises(CommClosedError):
        yield comm.write("foo")
    assert comm.closed()

    # Scenario 3: wait for the client side to notice the server's close.
    comm = yield connect(contact_addr)
    yield comm.write("foo")

    start = time()
    while not comm.closed():
        yield gen.sleep(0.01)
        assert time() < start + 2

    # Closing an already-closed comm (twice) must not raise.
    yield comm.close()
    yield comm.close()
Esempio n. 24
0
def check_deserialize_eoferror(addr):
    """
    EOFError when deserializing should close the comm.

    The server writes a message whose payload (``_EOFRaising``) presumably
    raises EOFError when deserialized. The client's read must raise
    ``CommClosedError``, and the server handler expects the same on its next
    read.
    """
    @gen.coroutine
    def handle_comm(comm):
        yield comm.write({'data': to_serialize(_EOFRaising())})
        with pytest.raises(CommClosedError):
            yield comm.read()

    with listen(addr, handle_comm) as listener:
        # Deserialization must be enabled for the EOFError to be triggered.
        # Pass an explicit boolean: no local ``deserialize`` exists in this
        # function, so a bare name would resolve to a module-level object
        # (presumably the protocol function) rather than a flag.
        comm = yield connect(listener.contact_address, deserialize=True)
        with pytest.raises(CommClosedError):
            yield comm.read()
Esempio n. 25
0
def test_tls_listen_connect():
    """
    Functional test for TLS connection args.

    A worker-role comm receives a greeting over TLS. The 'client' role, for
    which no TLS settings are configured here, must be rejected with
    TypeError. A forced cipher must be honoured; TLS 1.3 suites are also
    accepted.
    """
    @gen.coroutine
    def handle_comm(comm):
        assert comm.peer_address.startswith('tls://')
        yield comm.write('hello')
        yield comm.close()

    tls_config = {
        'tls': {
            'ca-file': ca_file,
            'scheduler': {'key': key1, 'cert': cert1},
            'worker': {'cert': keycert1},
        },
    }
    with new_config(tls_config):
        sec = Security()

    tls_config['tls']['ciphers'] = FORCED_CIPHER
    with new_config(tls_config):
        forced_cipher_sec = Security()

    listen_args = sec.get_listen_args('scheduler')
    with listen('tls://', handle_comm, connection_args=listen_args) as listener:
        address = listener.contact_address

        worker_comm = yield connect(
            address, connection_args=sec.get_connection_args('worker'))
        greeting = yield worker_comm.read()
        assert greeting == 'hello'
        worker_comm.abort()

        # No SSL context for client
        with pytest.raises(TypeError):
            yield connect(address,
                          connection_args=sec.get_connection_args('client'))

        # Check forced cipher
        forced_comm = yield connect(
            address,
            connection_args=forced_cipher_sec.get_connection_args('worker'))
        cipher, _, _ = forced_comm.extra_info['cipher']
        assert cipher in [FORCED_CIPHER] + TLS_13_CIPHERS
        forced_comm.abort()
Esempio n. 26
0
def get_comm_pair(listen_addr, listen_args=None, connect_args=None,
                  **kwargs):
    """Produce a connected ``(client_comm, server_comm)`` pair on *listen_addr*.

    Generator-based coroutine; the pair is delivered via ``gen.Return``.
    ``kwargs`` are forwarded to both ``listen`` and ``connect``. The listener
    is started here and never stopped.
    """
    accepted = queues.Queue()

    def on_connect(server_side):
        accepted.put(server_side)

    listener = listen(listen_addr, on_connect,
                      connection_args=listen_args, **kwargs)
    listener.start()

    client_side = yield connect(listener.contact_address,
                                connection_args=connect_args, **kwargs)
    server_side = yield accepted.get()
    raise gen.Return((client_side, server_side))
Esempio n. 27
0
async def check_connector_deserialize(addr, deserialize, in_value, check_out):
    """Check the value read by a connector created with ``deserialize``.

    The server writes *in_value*, then keeps its end open until the client
    signals (via an event) that it has read the message. The client comm is
    opened inside the listener's ``async with`` block and used after leaving
    it. *check_out* is called with the value the client read.
    """
    client_has_read = asyncio.Event()

    async def handle_comm(comm):
        await comm.write(in_value)
        await client_has_read.wait()
        await comm.close()

    async with listen(addr, handle_comm) as listener:
        client = await connect(listener.contact_address, deserialize=deserialize)

    received = await client.read()
    client_has_read.set()
    await client.close()
    check_out(received)
Esempio n. 28
0
async def test_tls_listen_connect():
    """
    Functional test for TLS connection args.

    A worker-role comm receives a greeting over TLS. The "client" role, for
    which no TLS settings are configured here, must be rejected with
    TypeError. A forced cipher must be honoured; TLS 1.3 suites are also
    accepted.
    """

    async def handle_comm(comm):
        assert comm.peer_address.startswith("tls://")
        await comm.write("hello")
        await comm.close()

    tls_settings = {
        "distributed.comm.tls.ca-file": ca_file,
        "distributed.comm.tls.scheduler.key": key1,
        "distributed.comm.tls.scheduler.cert": cert1,
        "distributed.comm.tls.worker.cert": keycert1,
    }
    with dask.config.set(tls_settings):
        sec = Security()

    forced_settings = dict(tls_settings)
    forced_settings["distributed.comm.tls.ciphers"] = FORCED_CIPHER
    with dask.config.set(forced_settings):
        forced_cipher_sec = Security()

    listen_args = sec.get_listen_args("scheduler")
    async with listen("tls://", handle_comm, connection_args=listen_args) as listener:
        address = listener.contact_address

        worker_comm = await connect(
            address, connection_args=sec.get_connection_args("worker")
        )
        greeting = await worker_comm.read()
        assert greeting == "hello"
        worker_comm.abort()

        # No SSL context for client
        with pytest.raises(TypeError):
            await connect(address, connection_args=sec.get_connection_args("client"))

        # Check forced cipher
        forced_comm = await connect(
            address, connection_args=forced_cipher_sec.get_connection_args("worker")
        )
        cipher, _, _ = forced_comm.extra_info["cipher"]
        assert cipher in [FORCED_CIPHER] + TLS_13_CIPHERS
        forced_comm.abort()
Esempio n. 29
0
def check_connector_deserialize(addr, deserialize, in_value, check_out):
    """Check the value read by a connector created with ``deserialize``.

    The server writes *in_value*, then keeps its end open until the client
    signals (via an event) that it has read the message. The client comm is
    opened inside the listener's ``with`` block and used after leaving it.
    *check_out* is called with the value the client read.
    """
    client_done = locks.Event()

    @gen.coroutine
    def handle_comm(comm):
        yield comm.write(in_value)
        yield client_done.wait()
        yield comm.close()

    with listen(addr, handle_comm) as listener:
        client = yield connect(listener.contact_address, deserialize=deserialize)

    received = yield client.read()
    client_done.set()
    yield client.close()
    check_out(received)
Esempio n. 30
0
def check_connector_deserialize(addr, deserialize, in_value, check_out):
    """Check the value read by a connector created with ``deserialize``.

    The server writes *in_value*, then keeps its end open until the client
    signals (via an event) that it has read the message. The client comm is
    opened inside the listener's ``with`` block and used after leaving it.
    *check_out* is called with the value the client read.
    """
    client_done = locks.Event()

    @gen.coroutine
    def handle_comm(comm):
        yield comm.write(in_value)
        yield client_done.wait()
        yield comm.close()

    with listen(addr, handle_comm) as listener:
        client = yield connect(listener.contact_address, deserialize=deserialize)

    received = yield client.read()
    client_done.set()
    yield client.close()
    check_out(received)
Esempio n. 31
0
async def check_listener_deserialize(addr, deserialize, in_value, check_out):
    """Check the value received by a listener created with ``deserialize``.

    The client sends *in_value*. The server handler reads one message, hands
    it over through a queue and closes its end. *check_out* is called with
    the value the server read.
    """
    inbox = asyncio.Queue()

    async def handle_comm(comm):
        inbox.put_nowait(await comm.read())
        await comm.close()

    async with listen(addr, handle_comm, deserialize=deserialize) as listener:
        client = await connect(listener.contact_address)

    await client.write(in_value)

    check_out(await inbox.get())
    await client.close()
Esempio n. 32
0
def check_comm_closed_implicit(addr, delay=None):
    """Check that a comm closed by the remote end raises on the local side.

    The listener closes every incoming comm straight away. A ``write`` on one
    client comm and a ``read`` on another must both raise ``CommClosedError``.
    *delay* is currently unused. The listener is stopped afterwards, even if
    an assertion fails.
    """
    @gen.coroutine
    def handle_comm(comm):
        yield comm.close()

    listener = listen(addr, handle_comm)
    listener.start()
    try:
        contact_addr = listener.contact_address

        comm = yield connect(contact_addr)
        with pytest.raises(CommClosedError):
            yield comm.write({})

        comm = yield connect(contact_addr)
        with pytest.raises(CommClosedError):
            yield comm.read()
    finally:
        listener.stop()
Esempio n. 33
0
def check_connector_deserialize(addr, deserialize, in_value, check_out):
    """Check the value read by a connector created with ``deserialize``.

    The server handler waits for *in_value* to be queued, writes it and closes
    its end. The value is only queued after the listener's ``with`` block has
    been left. *check_out* is called with the value the client read.
    """
    outbox = queues.Queue()

    @gen.coroutine
    def handle_comm(comm):
        outgoing = yield outbox.get()
        yield comm.write(outgoing)
        yield comm.close()

    with listen(addr, handle_comm) as listener:
        client = yield connect(listener.contact_address, deserialize=deserialize)

    outbox.put_nowait(in_value)
    received = yield client.read()
    yield client.close()
    check_out(received)
Esempio n. 34
0
def check_listener_deserialize(addr, deserialize, in_value, check_out):
    """Check the listener-side ``deserialize`` option (Tornado variant).

    The client writes ``in_value``. The listener, created with
    ``deserialize=deserialize``, forwards what it reads through a queue to the
    test body, which passes it to ``check_out``.
    """
    inbox = queues.Queue()

    @gen.coroutine
    def handle_comm(comm):
        # Forward whatever the listener-side comm decodes to the test body.
        decoded = yield comm.read()
        inbox.put_nowait(decoded)
        yield comm.close()

    with listen(addr, handle_comm, deserialize=deserialize) as listener:
        client = yield connect(listener.contact_address)

    yield client.write(in_value)

    result = yield inbox.get()
    check_out(result)
    yield client.close()
Esempio n. 35
0
async def check_comm_closed_implicit(
    addr, delay=None, listen_args=None, connect_args=None
):
    """Check that a comm closed by the remote end raises ``CommClosedError``.

    The listener's handler closes every incoming comm immediately. A write on
    one client comm, and a read on a second one, must both raise
    ``CommClosedError``.

    Parameters
    ----------
    addr : str
        Address to listen on.
    delay : optional
        Accepted for signature compatibility; not used here.
    listen_args, connect_args : dict, optional
        Forwarded as ``connection_args`` to ``listen`` / ``connect``.
    """
    async def handle_comm(comm):
        await comm.close()

    listener = listen(addr, handle_comm, connection_args=listen_args)
    await listener.start()
    try:
        contact_addr = listener.contact_address

        comm = await connect(contact_addr, connection_args=connect_args)
        with pytest.raises(CommClosedError):
            await comm.write({})

        comm = await connect(contact_addr, connection_args=connect_args)
        with pytest.raises(CommClosedError):
            await comm.read()
    finally:
        # Do not leak the listener, whether the checks pass or fail.
        listener.stop()
Esempio n. 36
0
def check_comm_closed_implicit(addr, delay=None, listen_args=None,
                               connect_args=None):
    """Check that a comm closed by the remote end raises ``CommClosedError``.

    The listener's handler closes every incoming comm immediately. A write on
    one client comm, and a read on a second one, must both raise
    ``CommClosedError``. Tornado-style generator (uses ``yield`` on futures).

    ``delay`` is accepted for signature compatibility but is not used.
    ``listen_args`` / ``connect_args`` are forwarded as ``connection_args``.
    """
    @gen.coroutine
    def handle_comm(comm):
        yield comm.close()

    listener = listen(addr, handle_comm, connection_args=listen_args)
    listener.start()
    try:
        contact_addr = listener.contact_address

        comm = yield connect(contact_addr, connection_args=connect_args)
        with pytest.raises(CommClosedError):
            yield comm.write({})

        comm = yield connect(contact_addr, connection_args=connect_args)
        with pytest.raises(CommClosedError):
            yield comm.read()
    finally:
        # Do not leak the listener, whether the checks pass or fail.
        listener.stop()
Esempio n. 37
0
async def test_listen_connect_wss():
    """Round-trip a bytes message over a TLS-secured ``wss://`` comm."""
    async def handle_comm(comm):
        # Echo every incoming message back to the sender.
        while True:
            await comm.write(await comm.read())

    srv_ssl = get_server_ssl_context()
    cli_ssl = get_client_ssl_context()
    payload = b"Hello!"

    async with listen("wss://", handle_comm, ssl_context=srv_ssl) as listener:
        client = await connect(listener.contact_address, ssl_context=cli_ssl)
        assert client.peer_address.startswith("wss://")
        check_tls_extra(client.extra_info)
        await client.write(payload)
        echoed = await client.read()
        assert echoed == payload
        await client.close()
Esempio n. 38
0
async def check_many_listeners(addr):
    """Start 100 listeners on ``addr`` and check their addresses are distinct.

    Both the listen addresses and the contact addresses must be unique across
    all listeners. Every listener that was started is stopped afterwards,
    even when an assertion fails.
    """
    async def handle_comm(comm):
        pass

    listeners = []
    N = 100

    try:
        for _ in range(N):
            listener = listen(addr, handle_comm)
            await listener.start()
            listeners.append(listener)

        assert len({lst.listen_address for lst in listeners}) == N
        assert len({lst.contact_address for lst in listeners}) == N
    finally:
        for listener in listeners:
            listener.stop()
Esempio n. 39
0
async def check_comm_closed_implicit(addr,
                                     delay=None,
                                     listen_args=None,
                                     connect_args=None):
    """Check that a comm closed by the remote end raises ``CommClosedError``.

    The listener's handler closes every incoming comm immediately. On the
    first client comm, a write followed by a read must raise
    ``CommClosedError`` at some point. On the second, a read must raise it.

    Parameters
    ----------
    addr : str
        Address to listen on.
    delay : optional
        Accepted for signature compatibility; not used here.
    listen_args, connect_args : dict, optional
        Extra keyword arguments for ``listen`` / ``connect``. They default to
        ``None`` rather than ``{}`` to avoid shared mutable defaults.
    """
    listen_args = {} if listen_args is None else listen_args
    connect_args = {} if connect_args is None else connect_args

    async def handle_comm(comm):
        await comm.close()

    async with listen(addr, handle_comm, **listen_args) as listener:

        comm = await connect(listener.contact_address, **connect_args)
        with pytest.raises(CommClosedError):
            # The write may succeed if the close has not been observed yet;
            # the following read must then raise.
            await comm.write({})
            await comm.read()

        comm = await connect(listener.contact_address, **connect_args)
        with pytest.raises(CommClosedError):
            await comm.read()
Esempio n. 40
0
def check_many_listeners(addr):
    """Start 100 listeners on ``addr`` and check each gets unique addresses.

    Listeners are started synchronously via ``start()``. Both the listen
    addresses and the contact addresses must be pairwise distinct. All
    listeners are stopped at the end.
    """
    @gen.coroutine
    def handle_comm(comm):
        pass

    started = []
    count = 100

    for _ in range(count):
        srv = listen(addr, handle_comm)
        srv.start()
        started.append(srv)

    assert len({srv.listen_address for srv in started}) == count
    assert len({srv.contact_address for srv in started}) == count

    for srv in started:
        srv.stop()
Esempio n. 41
0
def check_client_server(addr, check_listen_addr=None, check_contact_addr=None,
                        listen_args=None, connect_args=None):
    """
    Abstract client / server test.

    Starts a listener on ``addr`` and validates its listen and contact
    addresses. It then runs one client, followed by 20 concurrent clients,
    through a ping/pong exchange. This is a Tornado-style generator: ``yield``
    waits on futures (and on a list of futures for the concurrent clients).

    Parameters
    ----------
    addr : str
        Address to listen on. Its scheme must be 'inproc', 'tcp' or 'tls'.
    check_listen_addr, check_contact_addr : callable, optional
        Called with the location part of the bound / contact address. When
        ``check_contact_addr`` is omitted, the contact address must equal the
        bound address.
    listen_args, connect_args : dict, optional
        Passed as ``connection_args``. The defaults carry a dummy ``xxx`` key
        to check that unknown connection args are ignored.
    """
    @gen.coroutine
    def handle_comm(comm):
        scheme, loc = parse_address(comm.peer_address)
        # bound_scheme is assigned below, before any client connects.
        assert scheme == bound_scheme

        msg = yield comm.read()
        assert msg['op'] == 'ping'
        msg['op'] = 'pong'
        yield comm.write(msg)

        msg = yield comm.read()
        assert msg['op'] == 'foobar'

        yield comm.close()

    # Arbitrary connection args should be ignored
    listen_args = listen_args or {'xxx': 'bar'}
    connect_args = connect_args or {'xxx': 'foo'}

    listener = listen(addr, handle_comm, connection_args=listen_args)
    listener.start()

    try:
        # Check listener properties
        bound_addr = listener.listen_address
        bound_scheme, bound_loc = parse_address(bound_addr)
        assert bound_scheme in ('inproc', 'tcp', 'tls')
        assert bound_scheme == parse_address(addr)[0]

        if check_listen_addr is not None:
            check_listen_addr(bound_loc)

        contact_addr = listener.contact_address
        contact_scheme, contact_loc = parse_address(contact_addr)
        assert contact_scheme == bound_scheme

        if check_contact_addr is not None:
            check_contact_addr(contact_loc)
        else:
            assert contact_addr == bound_addr

        # Check client <-> server comms
        seen_keys = []

        @gen.coroutine
        def client_communicate(key, delay=0):
            comm = yield connect(listener.contact_address,
                                 connection_args=connect_args)
            assert comm.peer_address == listener.contact_address

            yield comm.write({'op': 'ping', 'data': key})
            yield comm.write({'op': 'foobar'})
            if delay:
                yield gen.sleep(delay)
            msg = yield comm.read()
            assert msg == {'op': 'pong', 'data': key}
            seen_keys.append(key)
            yield comm.close()

        yield client_communicate(key=1234)

        # Many clients at once
        futures = [client_communicate(key=i, delay=0.05) for i in range(20)]
        yield futures
        assert set(seen_keys) == {1234} | set(range(20))
    finally:
        # Stop the listener even if an address check or a client fails.
        listener.stop()