async def test_workflow_connect_fail(
    data_store_mgr: DataStoreMgr,
    port_range,
    monkeypatch,
    caplog,
):
    """Simulate a failure during workflow connection.

    The data store should "rollback" any incidental changes made during the
    failed connection attempt by disconnecting from the workflow afterwards.

    Not an ideal test as we don't actually get a communication failure,
    the data store manager just skips contacting the workflow because
    we aren't providing a client for it to connect to, however, probably the
    best we can achieve without actually running a workflow.
    """
    # patch the zmq logic so that the connection doesn't fail at the first
    # hurdle
    monkeypatch.setattr(
        'cylc.flow.network.ZMQSocketBase._socket_bind',
        lambda *a, **k: None,
    )

    # start a ZMQ REPLY socket in order to claim an unused port
    # NOTE(review): ``_socket_bind`` is patched to a no-op above, so the bind
    # below presumably does not claim a port and ``server.port`` may be unset
    # - confirm this is intended.
    w_id = Tokens(user='user', workflow='workflow_id').id
    # Create the context before the ``try`` and only stop ``server`` if it
    # was constructed, so the ``finally`` clause never hits an unbound name.
    context = zmq.Context()
    server = None
    try:
        server = ZMQSocketBase(
            zmq.REP,
            context=context,
            workflow=w_id,
            bind=True,
        )
        server._socket_bind(*port_range)

        # register the workflow with the data store
        await data_store_mgr.register_workflow(w_id=w_id, is_active=False)
        contact_data = {
            'name': 'workflow_id',
            'owner': 'cylc',
            CFF.HOST: 'localhost',
            CFF.PORT: server.port,
            CFF.PUBLISH_PORT: server.port,
            CFF.API: 1
        }

        # try to connect to the workflow
        caplog.set_level(logging.DEBUG, data_store_mgr.log.name)
        await data_store_mgr.connect_workflow(w_id, contact_data)

        # the connection should fail because our ZMQ socket is not a
        # WorkflowRuntimeServer with the correct endpoints and auth;
        # expected messages are built from ``w_id`` so they stay in step with
        # the workflow ID used above
        assert [record.message for record in caplog.records] == [
            f"[data-store] connect_workflow('{w_id}', <dict>)",
            f'failed to connect to {w_id}',
            f"[data-store] disconnect_workflow('{w_id}')",
        ]
    finally:
        # tidy up
        if server is not None:
            server.stop()
        context.destroy()
def test_client_requires_valid_client_private_key():
    """Client should not be able to connect to host/port without client
    private key.

    ``start`` is pointed at a non-existent server public key location and
    must raise ``ClientError``; the socket is stopped afterwards regardless
    of whether the expected error occurred.
    """
    port = random.choice(PORT_RANGE)
    client = ZMQSocketBase(zmq.REP, suite=f"test_suite-{time()}")
    try:
        with pytest.raises(ClientError, match=r"Failed to find user's private "
                                              r"key, so cannot connect."):
            client.start(HOST, port, srv_public_key_loc="fake_location")
    finally:
        # release the socket even if pytest.raises reports a failure
        client.stop()
def test_server_requires_valid_keys():
    """Server should not be able to connect to host/port without valid keys.

    An empty temporary file is passed as the private key location; starting
    the server must raise ``ValueError``. The server is always stopped.
    """
    with TemporaryDirectory() as keys, NamedTemporaryFile(dir=keys) as fake:
        # Assign a blank file masquerading as a CurveZMQ certificate
        server = ZMQSocketBase(zmq.REQ, bind=True, daemon=True)
        try:
            with pytest.raises(ValueError, match=r"No public key found in "):
                server.start(*PORT_RANGE, private_key_location=fake.name)
        finally:
            # release the socket even if pytest.raises reports a failure
            server.stop()
def test_server_cannot_start_when_certificate_file_only_contains_public_key():
    """Server should not be able to start when its certificate file does not
    contain the private key.

    A real key pair is generated and the *public* certificate is passed as
    the server's private key location; ``start`` must raise
    ``SuiteServiceFileError``. The server is always stopped.
    """
    with TemporaryDirectory() as keys:
        pub, _priv = zmq.auth.create_certificates(keys, "server")
        server = ZMQSocketBase(zmq.REQ, bind=True, daemon=True)
        try:
            with pytest.raises(
                SuiteServiceFileError,
                match=r"Failed to find server's private key in "
            ):
                server.start(*PORT_RANGE, srv_prv_key_loc=pub)
        finally:
            # release the socket even if pytest.raises reports a failure
            server.stop()
def test_server_cannot_start_when_server_private_key_cannot_be_loaded():
    """Server should not be able to start when its private key file cannot be
    opened.

    A path that does not exist is given as the private key location;
    ``start`` must raise ``SuiteServiceFileError``. The server is always
    stopped.
    """
    server = ZMQSocketBase(
        zmq.REQ, suite=f"test_suite-{time()}", bind=True, daemon=True)
    try:
        with pytest.raises(
            SuiteServiceFileError,
            match=r"IO error opening server's private key from "
        ):
            server.start(*PORT_RANGE, srv_prv_key_loc="fake_dir/fake_location")
    finally:
        # release the socket even if pytest.raises reports a failure
        server.stop()
def test_single_port(myflow, port_range):
    """Test server on a single port and port in use exception.

    ``serv1`` binds within ``port_range``; ``serv2`` then tries to bind to
    exactly that port and must raise ``CylcError``. Both sockets and the
    shared context are released whether or not the assertions pass.
    """
    setup_keys(myflow)  # auth keys are required for comms
    context = zmq.Context()
    # NOTE(review): other tests using the ``myflow`` fixture pass
    # ``workflow=`` rather than ``suite=`` - confirm which keyword this
    # ZMQSocketBase version accepts.
    serv1 = ZMQSocketBase(zmq.REP, context=context, suite=myflow, bind=True)
    serv2 = ZMQSocketBase(zmq.REP, context=context, suite=myflow, bind=True)
    try:
        serv1._socket_bind(*port_range)
        port = serv1.port
        # restrict serv2 to the single port serv1 already holds
        with pytest.raises(CylcError, match=r"Address already in use"):
            serv2._socket_bind(port, port)
    finally:
        serv2.stop()
        serv1.stop()
        context.destroy()
def test_server_cannot_start_when_public_key_not_found_in_certificate_file():
    """Server should not be able to start when its private key file does not
    contain the public key.

    An empty ``server.key_secret`` file is created and passed as the private
    key location; ``start`` must raise ``SuiteServiceFileError``. The server
    is always stopped.
    """
    with TemporaryDirectory() as keys:
        priv_key_loc = os.path.join(keys, "server.key_secret")
        # create (or touch) an empty key file
        with open(priv_key_loc, 'a'):
            pass
        server = ZMQSocketBase(zmq.REQ, bind=True, daemon=True)
        try:
            with pytest.raises(
                SuiteServiceFileError,
                match=r"Failed to find server's public key in "
            ):
                server.start(*PORT_RANGE, srv_prv_key_loc=priv_key_loc)
        finally:
            # release the socket even if pytest.raises reports a failure
            server.stop()
def test_stop():
    """Test socket/thread stop.

    Starts a threaded PUB socket that shares a two-party barrier with this
    test, waits until another party is waiting on that barrier, releases it,
    then checks that ``stop()`` closes the socket and ends the thread.
    """
    create_auth_files('test_zmq_stop')  # auth keys are required for comms
    timeout = 20  # seconds; used for both the barrier and the polling loop
    poll_interval = 0.2
    barrier = Barrier(2, timeout=timeout)
    publisher = ZMQSocketBase(zmq.PUB, suite='test_zmq_stop',
                              bind=True, barrier=barrier,
                              threaded=True, daemon=True)
    publisher.start(*PORT_RANGE)
    # barrier.wait() doesn't seem to work properly here
    # so this workaround will do: poll until one party is waiting on the
    # barrier, but fail (instead of hanging forever) if that never happens.
    for _ in range(round(timeout / poll_interval)):
        if publisher.barrier.n_waiting >= 1:
            break
        sleep(poll_interval)
    else:
        publisher.stop()
        pytest.fail('publisher thread did not reach the barrier in time')
    barrier.wait()
    assert not publisher.socket.closed
    assert publisher.thread.is_alive()
    publisher.stop()
    assert publisher.socket.closed
    assert not publisher.thread.is_alive()
def test_stop(myflow, port_range):
    """Test socket/thread stop.

    Starts a threaded PUB socket that shares a two-party barrier with this
    test, waits until another party is waiting on that barrier, releases it,
    then checks that ``stop()`` closes the socket and ends the thread.
    """
    setup_keys(myflow)  # auth keys are required for comms
    timeout = 20  # seconds; used for both the barrier and the polling loop
    poll_interval = 0.1
    barrier = Barrier(2, timeout=timeout)
    publisher = ZMQSocketBase(zmq.PUB, workflow=myflow,
                              bind=True, barrier=barrier,
                              threaded=True, daemon=True)
    publisher.start(*port_range)
    # barrier.wait() doesn't seem to work properly here
    # so this workaround will do: poll until one party is waiting on the
    # barrier, but fail (instead of hanging forever) if that never happens.
    for _ in range(round(timeout / poll_interval)):
        if publisher.barrier.n_waiting >= 1:
            break
        sleep(poll_interval)
    else:
        publisher.stop()
        pytest.fail('publisher thread did not reach the barrier in time')
    barrier.wait()
    assert not publisher.socket.closed
    assert publisher.thread.is_alive()
    publisher.stop()
    assert publisher.socket.closed
    assert not publisher.thread.is_alive()
def test_start():
    """Test socket start.

    Before ``start()`` the publisher has no loop or port and nobody waits on
    the barrier; after ``start()`` and the barrier rendezvous both the loop
    and the port must be set. The publisher is always stopped at the end.
    """
    create_auth_files('test_zmq_start')  # auth keys are required for comms
    timeout = 20  # seconds; used for both the barrier and the polling loop
    poll_interval = 0.2
    barrier = Barrier(2, timeout=timeout)
    publisher = ZMQSocketBase(zmq.PUB, suite='test_zmq_start',
                              bind=True, barrier=barrier,
                              threaded=True, daemon=True)
    assert publisher.barrier.n_waiting == 0
    assert publisher.loop is None
    assert publisher.port is None
    publisher.start(*PORT_RANGE)
    try:
        # barrier.wait() doesn't seem to work properly here
        # so this workaround will do: poll until one party is waiting on the
        # barrier, but fail (instead of hanging forever) if that never
        # happens.
        for _ in range(round(timeout / poll_interval)):
            if publisher.barrier.n_waiting >= 1:
                break
            sleep(poll_interval)
        else:
            pytest.fail('publisher thread did not reach the barrier in time')
        # call wait() outside the assert so the rendezvous still happens
        # when assertions are stripped (python -O)
        arrival_index = barrier.wait()
        assert arrival_index == 1
        assert publisher.loop is not None
        assert publisher.port is not None
    finally:
        publisher.stop()
def test_single_port():
    """Test server on a single port and port in use exception.

    ``serv1`` binds within ``PORT_RANGE``; ``serv2`` then tries to bind to
    exactly that port and must raise ``CylcError``. Both sockets and the
    shared context are released whether or not the assertions pass.
    """
    create_auth_files('test_zmq')  # auth keys are required for comms
    context = zmq.Context()
    serv1 = ZMQSocketBase(
        zmq.REP, context=context, suite='test_zmq', bind=True)
    serv2 = ZMQSocketBase(
        zmq.REP, context=context, suite='test_zmq', bind=True)
    try:
        serv1._socket_bind(*PORT_RANGE)
        port = serv1.port
        # restrict serv2 to the single port serv1 already holds
        with pytest.raises(CylcError, match=r"Address already in use"):
            serv2._socket_bind(port, port)
    finally:
        serv2.stop()
        serv1.stop()
        context.destroy()
def test_start(myflow, port_range):
    """Test socket start.

    Before ``start()`` the publisher has no loop or port and nobody waits on
    the barrier; after ``start()`` and the barrier rendezvous both the loop
    and the port must be set. The publisher is always stopped at the end.
    """
    setup_keys(myflow)  # auth keys are required for comms
    timeout = 20  # seconds; used for both the barrier and the polling loop
    poll_interval = 0.2
    barrier = Barrier(2, timeout=timeout)
    publisher = ZMQSocketBase(zmq.PUB, workflow=myflow,
                              bind=True, barrier=barrier,
                              threaded=True, daemon=True)
    assert publisher.barrier.n_waiting == 0
    assert publisher.loop is None
    assert publisher.port is None
    publisher.start(*port_range)
    try:
        # barrier.wait() doesn't seem to work properly here
        # so this workaround will do: poll until one party is waiting on the
        # barrier, but fail (instead of hanging forever) if that never
        # happens.
        for _ in range(round(timeout / poll_interval)):
            if publisher.barrier.n_waiting >= 1:
                break
            sleep(poll_interval)
        else:
            pytest.fail('publisher thread did not reach the barrier in time')
        # call wait() outside the assert so the rendezvous still happens
        # when assertions are stripped (python -O)
        arrival_index = barrier.wait()
        assert arrival_index == 1
        assert publisher.loop is not None
        assert publisher.port is not None
    finally:
        publisher.stop()
def test_client_requires_valid_server_public_key_in_private_key_file():
    """Client should not be able to connect to host/port without
    server public key.

    A client key pair is written into the suite's client private-key
    directory, but ``start`` is pointed at a non-existent server public key,
    so it must raise ``ClientError``. The client is stopped and the
    temporary suite directory removed whether or not the assertion passes.
    """
    suite_name = f"test_suite-{time()}"
    port = random.choice(PORT_RANGE)
    client = ZMQSocketBase(zmq.REP, suite=suite_name)

    test_suite_srv_dir = get_suite_srv_dir(reg=suite_name)
    key_info = KeyInfo(
        KeyType.PRIVATE,
        KeyOwner.CLIENT,
        suite_srv_dir=test_suite_srv_dir)
    # NOTE(review): assumes get_suite_srv_dir() lives under
    # ~/cylc-run/<suite_name>; confirm, otherwise the cleanup misses it.
    tmpdir = os.path.join(os.path.expanduser("~/cylc-run"), suite_name)

    try:
        os.makedirs(key_info.key_path, exist_ok=True)
        zmq.auth.create_certificates(key_info.key_path, "client")

        with pytest.raises(ClientError, match=r"Failed to load the suite's public "
                                              r"key, so cannot connect."):
            client.start(HOST, port, srv_public_key_loc="fake_location")
    finally:
        # always release the socket and remove files created under ~/cylc-run
        client.stop()
        rmtree(tmpdir, ignore_errors=True)