예제 #1
0
def test_client_requires_valid_keys():
    """Client should not be able to connect to host/port without valid keys.

    An empty temporary file is handed to ``start`` as the server public key
    location; the call is expected to raise ``ClientError`` with the
    "Failed to load the suite's public key" message.
    """
    with TemporaryDirectory() as keys, NamedTemporaryFile(dir=keys) as fake:
        port = random.choice(PORT_RANGE)
        client = ZMQSocketBase(zmq.REP)

        try:
            with pytest.raises(
                ClientError,
                match=r"Failed to load the suite's public "
                      r"key, so cannot connect."
            ):
                # Assign a blank file masquerading as a CurveZMQ certificate
                client.start(HOST, port, srv_public_key_loc=fake.name)
        finally:
            # always release the client, whether or not the expected error
            # was raised, so a failing test does not leave it running
            client.stop()
async def test_workflow_connect_fail(
    data_store_mgr: DataStoreMgr,
    port_range,
    monkeypatch,
    caplog,
):
    """Simulate a failure during workflow connection.

    The data store should "rollback" any incidental changes made during the
    failed connection attempt by disconnecting from the workflow afterwards.

    Not an ideal test as we don't actually get a communication failure,
    the data store manager just skips contacting the workflow because
    we aren't providing a client for it to connect to, however, probably the
    best we can achieve without actually running a workflow.

    The test passes when the data store manager's log contains exactly the
    connect / failed-to-connect / disconnect messages asserted below.
    """
    # patch the zmq logic so that the connection doesn't fail at the first
    # hurdle
    monkeypatch.setattr(
        'cylc.flow.network.ZMQSocketBase._socket_bind', lambda *a, **k: None,
    )

    # the user must be 'user' so that the workflow ID matches the
    # '~user/workflow_id' log messages asserted at the end of the test
    w_id = Tokens(user='user', workflow='workflow_id').id

    # create the resources before entering the ``try`` so that the
    # ``finally`` clause never refers to a name that was not bound
    context = zmq.Context()
    server = None
    try:
        # start a ZMQ REPLY socket in order to claim an unused port
        server = ZMQSocketBase(
            zmq.REP,
            context=context,
            workflow=w_id,
            bind=True,
        )
        server._socket_bind(*port_range)

        # register the workflow with the data store
        await data_store_mgr.register_workflow(w_id=w_id, is_active=False)
        contact_data = {
            'name': 'workflow_id',
            'owner': 'cylc',
            CFF.HOST: 'localhost',
            CFF.PORT: server.port,
            CFF.PUBLISH_PORT: server.port,
            CFF.API: 1
        }

        # try to connect to the workflow
        caplog.set_level(logging.DEBUG, data_store_mgr.log.name)
        await data_store_mgr.connect_workflow(w_id, contact_data)

        # the connection should fail because our ZMQ socket is not a
        # WorkflowRuntimeServer with the correct endpoints and auth
        assert [record.message for record in caplog.records] == [
            "[data-store] connect_workflow('~user/workflow_id', <dict>)",
            'failed to connect to ~user/workflow_id',
            "[data-store] disconnect_workflow('~user/workflow_id')",
        ]
    finally:
        # tidy up
        if server is not None:
            server.stop()
        context.destroy()
예제 #3
0
def test_client_requires_valid_client_private_key():
    """Check that a client cannot start without its own private key.

    ``start`` is expected to raise ``ClientError`` reporting that the
    user's private key could not be found.
    """
    expected_error = (
        r"Failed to find user's private "
        r"key, so cannot connect."
    )
    # a time-stamped suite name is used, presumably so that no key files
    # exist for it yet -- TODO confirm against the key lookup logic
    zmq_client = ZMQSocketBase(zmq.REP, suite=f"test_suite-{time()}")
    chosen_port = random.choice(PORT_RANGE)

    with pytest.raises(ClientError, match=expected_error):
        zmq_client.start(
            HOST,
            chosen_port,
            srv_public_key_loc="fake_location",
        )

    zmq_client.stop()
예제 #4
0
def test_server_requires_valid_keys():
    """Check that a server refuses to start when given an empty key file.

    ``start`` is expected to raise ``ValueError`` with a message beginning
    "No public key found in ".
    """
    expected_error = r"No public key found in "

    with TemporaryDirectory() as key_dir:
        # a blank file masquerading as a CurveZMQ certificate
        with NamedTemporaryFile(dir=key_dir) as blank_cert:
            zmq_server = ZMQSocketBase(zmq.REQ, bind=True, daemon=True)

            with pytest.raises(ValueError, match=expected_error):
                zmq_server.start(
                    *PORT_RANGE,
                    private_key_location=blank_cert.name,
                )

            zmq_server.stop()
예제 #5
0
def test_server_cannot_start_when_certificate_file_only_contains_public_key():
    """Check that a server cannot start from a public-only certificate.

    A certificate pair is generated in a temporary directory and the
    *public* certificate file is passed where the private key is expected;
    ``start`` should raise ``SuiteServiceFileError``.
    """
    expected_error = r"Failed to find server's private key in "

    with TemporaryDirectory() as cert_dir:
        # only the public half of the generated pair is used
        public_cert, _ = zmq.auth.create_certificates(cert_dir, "server")

        zmq_server = ZMQSocketBase(zmq.REQ, bind=True, daemon=True)

        with pytest.raises(SuiteServiceFileError, match=expected_error):
            zmq_server.start(*PORT_RANGE, srv_prv_key_loc=public_cert)

        zmq_server.stop()
예제 #6
0
def test_server_cannot_start_when_server_private_key_cannot_be_loaded():
    """Check that a server cannot start if its private key file won't open.

    The private key location is a path that is not expected to exist;
    ``start`` should raise ``SuiteServiceFileError`` with an IO error
    message.
    """
    unreadable_key_path = "fake_dir/fake_location"
    expected_error = r"IO error opening server's private key from "

    zmq_server = ZMQSocketBase(
        zmq.REQ, suite=f"test_suite-{time()}", bind=True, daemon=True,
    )

    with pytest.raises(SuiteServiceFileError, match=expected_error):
        zmq_server.start(*PORT_RANGE, srv_prv_key_loc=unreadable_key_path)

    zmq_server.stop()
예제 #7
0
def test_server_cannot_start_when_public_key_not_found_in_certificate_file():
    """Check that a server cannot start from a key file with no public key.

    An empty ``server.key_secret`` file is created in a temporary directory
    and passed as the private key location; ``start`` should raise
    ``SuiteServiceFileError``.
    """
    expected_error = r"Failed to find server's public key in "

    with TemporaryDirectory() as key_dir:
        empty_key_file = os.path.join(key_dir, "server.key_secret")
        # create the file without writing anything to it
        with open(empty_key_file, 'a'):
            pass

        zmq_server = ZMQSocketBase(zmq.REQ, bind=True, daemon=True)

        with pytest.raises(SuiteServiceFileError, match=expected_error):
            zmq_server.start(*PORT_RANGE, srv_prv_key_loc=empty_key_file)

        zmq_server.stop()
예제 #8
0
def test_stop():
    """Verify that ``stop`` closes the socket and ends the thread.

    A threaded publisher is started, the test synchronises with it on a
    shared barrier, and the socket/thread state is checked before and
    after calling ``stop``.
    """
    suite_name = 'test_zmq_stop'
    create_auth_files(suite_name)  # auth keys are required for comms
    sync_point = Barrier(2, timeout=20)
    pub = ZMQSocketBase(
        zmq.PUB,
        suite=suite_name,
        bind=True,
        barrier=sync_point,
        threaded=True,
        daemon=True,
    )
    pub.start(*PORT_RANGE)

    # barrier.wait() doesn't seem to work properly here, so poll until the
    # other party is waiting on the barrier before joining it ourselves
    while True:
        if pub.barrier.n_waiting >= 1:
            break
        sleep(0.2)
    sync_point.wait()

    # running state
    assert not pub.socket.closed
    assert pub.thread.is_alive()

    pub.stop()

    # stopped state
    assert pub.socket.closed
    assert not pub.thread.is_alive()
예제 #9
0
def test_stop(myflow, port_range):
    """Verify that ``stop`` closes the socket and ends the thread.

    A threaded publisher is started for the ``myflow`` workflow, the test
    synchronises with it on a shared barrier, and the socket/thread state
    is checked before and after calling ``stop``.
    """
    setup_keys(myflow)  # auth keys are required for comms
    sync_point = Barrier(2, timeout=20)
    pub = ZMQSocketBase(
        zmq.PUB,
        workflow=myflow,
        bind=True,
        barrier=sync_point,
        threaded=True,
        daemon=True,
    )
    pub.start(*port_range)

    # barrier.wait() doesn't seem to work properly here, so poll until the
    # other party is waiting on the barrier before joining it ourselves
    while True:
        if pub.barrier.n_waiting >= 1:
            break
        sleep(0.1)
    sync_point.wait()

    # running state
    assert not pub.socket.closed
    assert pub.thread.is_alive()

    pub.stop()

    # stopped state
    assert pub.socket.closed
    assert not pub.thread.is_alive()
예제 #10
0
def test_start():
    """Verify that ``start`` populates the publisher's loop and port.

    Before ``start`` the publisher has no loop or port and nothing is
    waiting on the barrier; after synchronising on the barrier both
    attributes must be set.
    """
    suite_name = 'test_zmq_start'
    create_auth_files(suite_name)  # auth keys are required for comms
    sync_point = Barrier(2, timeout=20)
    pub = ZMQSocketBase(
        zmq.PUB,
        suite=suite_name,
        bind=True,
        barrier=sync_point,
        threaded=True,
        daemon=True,
    )

    # pristine state: nothing waiting, nothing started
    assert pub.barrier.n_waiting == 0
    assert pub.loop is None
    assert pub.port is None

    pub.start(*PORT_RANGE)

    # barrier.wait() doesn't seem to work properly here, so poll until the
    # other party is waiting on the barrier before joining it ourselves
    while True:
        if pub.barrier.n_waiting >= 1:
            break
        sleep(0.2)
    arrival_index = sync_point.wait()
    assert arrival_index == 1

    # started state
    assert pub.loop is not None
    assert pub.port is not None

    pub.stop()
예제 #11
0
def test_start(myflow, port_range):
    """Verify that ``start`` populates the publisher's loop and port.

    Before ``start`` the publisher has no loop or port and nothing is
    waiting on the barrier; after synchronising on the barrier both
    attributes must be set.
    """
    setup_keys(myflow)  # auth keys are required for comms
    sync_point = Barrier(2, timeout=20)
    pub = ZMQSocketBase(
        zmq.PUB,
        workflow=myflow,
        bind=True,
        barrier=sync_point,
        threaded=True,
        daemon=True,
    )

    # pristine state: nothing waiting, nothing started
    assert pub.barrier.n_waiting == 0
    assert pub.loop is None
    assert pub.port is None

    pub.start(*port_range)

    # barrier.wait() doesn't seem to work properly here, so poll until the
    # other party is waiting on the barrier before joining it ourselves
    while True:
        if pub.barrier.n_waiting >= 1:
            break
        sleep(0.2)
    arrival_index = sync_point.wait()
    assert arrival_index == 1

    # started state
    assert pub.loop is not None
    assert pub.port is not None

    pub.stop()
예제 #12
0
def test_client_requires_valid_server_public_key_in_private_key_file():
    """Client should not be able to connect to host/port without
    server public key.

    Client certificates are generated at the client private key path for a
    uniquely named suite, then ``start`` is called with a server public key
    location that is not expected to exist. ``ClientError`` must be raised.

    The suite's directory under ``~/cylc-run`` is removed afterwards, even
    when the assertion fails.
    """
    suite_name = f"test_suite-{time()}"
    port = random.choice(PORT_RANGE)
    client = ZMQSocketBase(zmq.REP, suite=suite_name)

    test_suite_srv_dir = get_suite_srv_dir(reg=suite_name)
    key_info = KeyInfo(
        KeyType.PRIVATE,
        KeyOwner.CLIENT,
        suite_srv_dir=test_suite_srv_dir)
    # NOTE(review): assumes key_info.key_path lives beneath this directory,
    # since this is the only location that gets cleaned up -- confirm
    tmpdir = os.path.join(os.path.expanduser("~/cylc-run"), suite_name)

    try:
        os.makedirs(key_info.key_path, exist_ok=True)

        # give the client valid keys of its own so that only the server
        # public key is missing
        zmq.auth.create_certificates(key_info.key_path, "client")

        with pytest.raises(
            ClientError,
            match=r"Failed to load the suite's public "
                  r"key, so cannot connect."
        ):
            client.start(HOST, port, srv_public_key_loc="fake_location")
    finally:
        # tidy up regardless of the test outcome so that failed runs do not
        # leave a running client or stray directories behind
        try:
            client.stop()
        finally:
            rmtree(tmpdir, ignore_errors=True)
예제 #13
0
def test_single_port():
    """Test server on a single port and port in use exception.

    The first server binds to a port from ``PORT_RANGE``; a second server
    sharing the same context is then restricted to exactly that port and
    must raise ``CylcError`` ("Address already in use").
    """
    create_auth_files('test_zmq')  # auth keys are required for comms
    context = zmq.Context()
    serv1 = ZMQSocketBase(
        zmq.REP, context=context, suite='test_zmq', bind=True)
    serv2 = ZMQSocketBase(
        zmq.REP, context=context, suite='test_zmq', bind=True)

    try:
        serv1._socket_bind(*PORT_RANGE)
        port = serv1.port

        # min == max == port leaves serv2 no alternative port to fall back on
        with pytest.raises(CylcError, match=r"Address already in use"):
            serv2._socket_bind(port, port)
    finally:
        # release the sockets and context even if the assertion fails so
        # the port is not left bound for subsequent tests
        serv2.stop()
        serv1.stop()
        context.destroy()
예제 #14
0
def test_single_port(myflow, port_range):
    """Test server on a single port and port in use exception.

    The first server binds to a port from ``port_range``; a second server
    sharing the same context is then restricted to exactly that port and
    must raise ``CylcError`` ("Address already in use").
    """
    setup_keys(myflow)  # auth keys are required for comms
    context = zmq.Context()
    serv1 = ZMQSocketBase(zmq.REP, context=context, suite=myflow, bind=True)
    serv2 = ZMQSocketBase(zmq.REP, context=context, suite=myflow, bind=True)

    try:
        serv1._socket_bind(*port_range)
        port = serv1.port

        # min == max == port leaves serv2 no alternative port to fall back on
        with pytest.raises(CylcError, match=r"Address already in use"):
            serv2._socket_bind(port, port)
    finally:
        # release the sockets and context even if the assertion fails so
        # the port is not left bound for subsequent tests
        serv2.stop()
        serv1.stop()
        context.destroy()