def _socket_connect(self, host, port, srv_public_key_loc=None):
    """Connect socket to stub.

    Loads the client key pair and the suite server's public key (both via
    ``KeyInfo`` objects), configures ``self.socket`` with them for
    CurveZMQ and connects it to ``tcp://{host}:{port}``.

    Args:
        host (str): Host name to connect to.
        port (int): Port number to connect to.
        srv_public_key_loc (str): Full path to the server public key
            file. If ``None``, the key is located via ``KeyInfo`` in the
            suite service directory.

    Raises:
        ClientError: If the client private key or the server public key
            cannot be loaded.
    """
    suite_srv_dir = get_suite_srv_dir(self.suite)
    if srv_public_key_loc is None:
        # Create new KeyInfo object for the server public key
        srv_pub_key_info = KeyInfo(
            KeyType.PUBLIC,
            KeyOwner.SERVER,
            suite_srv_dir=suite_srv_dir)
    else:
        srv_pub_key_info = KeyInfo(
            KeyType.PUBLIC,
            KeyOwner.SERVER,
            full_key_path=srv_public_key_loc)

    self.host = host
    self.port = port
    self.socket = self.context.socket(self.pattern)
    self._socket_options()

    client_priv_key_info = KeyInfo(
        KeyType.PRIVATE,
        KeyOwner.CLIENT,
        suite_srv_dir=suite_srv_dir)
    error_msg = "Failed to find user's private key, so cannot connect."
    try:
        client_public_key, client_priv_key = zmq.auth.load_certificate(
            client_priv_key_info.full_key_path)
    except (OSError, ValueError) as exc:
        # Chain the cause so the underlying I/O / parse error is kept.
        raise ClientError(error_msg) from exc
    if client_priv_key is None:  # this can't be caught by exception
        raise ClientError(error_msg)
    self.socket.curve_publickey = client_public_key
    self.socket.curve_secretkey = client_priv_key

    # A client can only connect to the server if it knows its public key,
    # so we grab this from the location it was created on the filesystem:
    try:
        # 'load_certificate' will try to load both public & private keys
        # from a provided file but will return None, not throw an error,
        # for the latter item if not there (as for all public key files)
        # so it is OK to use; there is no method to load only the
        # public key.
        server_public_key = zmq.auth.load_certificate(
            srv_pub_key_info.full_key_path)[0]
        self.socket.curve_serverkey = server_public_key
    except (OSError, ValueError) as exc:  # ValueError raised w/ no public key
        raise ClientError(
            "Failed to load the suite's public key, so cannot connect."
        ) from exc

    self.socket.connect(f'tcp://{host}:{port}')
async def async_request(self, command, args=None, timeout=None):
    """Send an asynchronous request using asyncio.

    Has the same arguments and return values as ``serial_request``.

    The message (``command``, ``args`` plus ``self.header``) is encrypted
    with the suite passphrase returned by ``self.secret()`` and sent on
    ``self.socket``. The reply is decrypted unless the command is listed in
    ``PB_METHOD_MAP``, in which case the raw reply is returned as-is.

    Raises:
        ClientError: If the passphrase cannot be read, the response cannot
            be decrypted, or the server returns an error.
        ClientTimeout: If no response arrives within the timeout.
    """
    if timeout:
        timeout = float(timeout)
    # Scale by 1000 before handing to ``self.poller.poll`` (presumably a
    # milliseconds API - confirm); fall back to ``self.timeout`` if unset.
    timeout = (timeout * 1000 if timeout else None) or self.timeout
    if not args:
        args = {}

    # get secret for this request
    # assumes secret won't change during the request
    try:
        secret = self.secret()
    except cylc.flow.suite_srv_files_mgr.SuiteServiceFileError as exc:
        raise ClientError('could not read suite passphrase') from exc

    # send message
    msg = {'command': command, 'args': args}
    msg.update(self.header)
    LOG.debug('zmq:send %s', msg)
    message = encrypt(msg, secret)
    self.socket.send_string(message)

    # receive response
    if self.poller.poll(timeout):
        res = await self.socket.recv()
    else:
        if callable(self.timeout_handler):
            self.timeout_handler()
        raise ClientTimeout('Timeout waiting for server response.')

    if msg['command'] in PB_METHOD_MAP:
        response = {'data': res}
    else:
        try:
            response = decrypt(res.decode(), secret)
        except jose.exceptions.JWTError as exc:
            raise ClientError(
                'Could not decrypt response. Has the passphrase changed?'
            ) from exc
    LOG.debug('zmq:recv %s', response)

    try:
        return response['data']
    except KeyError:
        # No 'data' key: the server replied with an error payload.
        error = response['error']
        raise ClientError(error['message'], error.get('traceback'))
def get_location(suite: str, owner: str, host: str):
    """Extract host and port from a suite's contact file.

    Args:
        suite (str): suite name
        owner (str): owner of the suite
        host (str): host name; if empty/None, the host recorded in the
            suite contact file is used instead.

    Returns:
        Tuple[str, int, int]: the host name (as returned by
        ``get_fqdn_by_host``), the port number and the publish port number.

    Raises:
        ClientError: if the suite contact file cannot be loaded (e.g. the
            suite is not running).
    """
    try:
        contact = load_contact_file(suite, owner, host)
    except SuiteServiceFileError as exc:
        raise ClientError('Contact info not found for suite '
                          f'"{suite}", suite not running?') from exc

    if not host:
        host = contact[ContactFileFields.HOST]
    host = get_fqdn_by_host(host)

    port = int(contact[ContactFileFields.PORT])
    pub_port = int(contact[ContactFileFields.PUBLISH_PORT])
    return host, port, pub_port
async def async_request(self, command, args=None, timeout=None):
    """Send asynchronous request via SSH.

    Builds the command with ``self.prepare_command``, launches it through
    ``_remote_cylc_cmd`` and polls the resulting process (yielding to the
    event loop between polls) until it exits, then JSON-decodes its
    standard output.

    Args:
        command (str): The name of the endpoint to call.
        args (dict): Arguments to pass to the endpoint function.
        timeout (float): Override the default timeout; ``self.timeout`` is
            used when this is ``None``.

    Returns:
        object: Deserialized output from the function called.

    Raises:
        ClientError: If the command exits with a non-zero return code.
        ClientTimeout: If the command does not finish within ``timeout``.
    """
    timeout = timeout if timeout is not None else self.timeout
    proc = None
    try:
        async with ascyncto(timeout):
            cmd, ssh_cmd, login_sh, cylc_path, msg = self.prepare_command(
                command, args, timeout)
            proc = _remote_cylc_cmd(
                cmd,
                host=self.host,
                stdin_str=msg,
                ssh_cmd=ssh_cmd,
                remote_cylc_path=cylc_path,
                ssh_login_shell=login_sh,
                capture_process=True,
            )
            # Poll rather than block in ``communicate`` so the event loop
            # (and hence the timeout) keeps running while we wait.
            while proc.poll() is None:
                await asyncio.sleep(self.SLEEP_INTERVAL)
            out, err = (f.decode() for f in proc.communicate())
            return_code = proc.wait()
            if return_code:
                raise ClientError(err, f"return-code={return_code}")
            return json.loads(out)
    except asyncio.TimeoutError as exc:
        raise ClientTimeout(f"Command exceeded the timeout {timeout}. "
                            f"This could be due to network problems. "
                            "Check the workflow log.") from exc
    finally:
        # On timeout (or cancellation) the subprocess may still be running;
        # kill and reap it instead of leaving an orphaned/zombie process.
        if proc is not None and proc.poll() is None:
            proc.kill()
            proc.wait()
def _socket_connect(self, host, port, srv_public_key_loc=None):
    """Connect socket to stub.

    Ensures user keys exist, loads the client key pair and the suite
    server's public key, configures ``self.socket`` with them for
    CurveZMQ and connects it to ``tcp://{host}:{port}``.

    Args:
        host (str): Host name to connect to.
        port (int): Port number to connect to.
        srv_public_key_loc (str): Path to the server public key
            certificate. If ``None``, it is looked up with
            ``get_auth_item`` for ``self.suite``.

    Raises:
        ClientError: If user keys cannot be generated, or the client
            private key or server public key cannot be loaded.
    """
    if srv_public_key_loc is None:
        srv_public_key_loc = get_auth_item(
            UserFiles.Auth.SERVER_PUBLIC_KEY_CERTIFICATE,
            self.suite,
            content=False
        )

    self.host = host
    self.port = port
    self.socket = self.context.socket(self.pattern)
    self._socket_options()

    # check for, & create if nonexistent, user keys in the right location
    if not ensure_user_keys_exist():
        raise ClientError("Unable to generate user authentication keys.")

    client_priv_keyfile = os.path.join(
        UserFiles.get_user_certificate_full_path(private=True),
        UserFiles.Auth.CLIENT_PRIVATE_KEY_CERTIFICATE)
    error_msg = "Failed to find user's private key, so cannot connect."
    try:
        client_public_key, client_priv_key = zmq.auth.load_certificate(
            client_priv_keyfile)
    except (OSError, ValueError) as exc:
        # Chain the cause so the underlying I/O / parse error is kept.
        raise ClientError(error_msg) from exc
    if client_priv_key is None:  # this can't be caught by exception
        raise ClientError(error_msg)
    self.socket.curve_publickey = client_public_key
    self.socket.curve_secretkey = client_priv_key

    # A client can only connect to the server if it knows its public key,
    # so we grab this from the location it was created on the filesystem:
    try:
        # 'load_certificate' will try to load both public & private keys
        # from a provided file but will return None, not throw an error,
        # for the latter item if not there (as for all public key files)
        # so it is OK to use; there is no method to load only the
        # public key.
        server_public_key = zmq.auth.load_certificate(
            srv_public_key_loc)[0]
        self.socket.curve_serverkey = server_public_key
    except (OSError, ValueError) as exc:  # ValueError raised w/ no public key
        raise ClientError(
            "Failed to load the suite's public key, so cannot connect."
        ) from exc

    self.socket.connect(f'tcp://{host}:{port}')
async def async_request(
    self, command, args=None, timeout=None, req_meta=None
):
    """Send an asynchronous request using asyncio.

    Has the same arguments and return values as ``serial_request``.

    Args:
        command (str): The name of the endpoint to call.
        args (dict): Arguments to pass to the endpoint function.
        timeout (float): Override the default timeout; it is scaled by 1000
            before being passed to ``self.poller.poll`` and falls back to
            ``self.timeout`` when unset.
        req_meta (dict): Extra request metadata merged into the message's
            ``'meta'`` entry for this request only.

    Raises:
        ClientTimeout: If no response arrives within the timeout.
        ClientError: If the server returns an error response.
    """
    timeout = (float(timeout) * 1000 if timeout else None) or self.timeout
    if not args:
        args = {}

    # Note: we are using CurveZMQ to secure the messages (see
    # self.curve_auth, self.socket.curve_...key etc.). We have set up
    # public-key cryptography on the ZMQ messaging and sockets, so
    # there is no need to encrypt messages ourselves before sending.

    # send message
    msg = {'command': command, 'args': args}
    msg.update(self.header)
    # add the request metadata
    if req_meta:
        # Build a new dict: ``msg.update(self.header)`` is only a shallow
        # copy, so updating ``msg['meta']`` in place would mutate
        # ``self.header['meta']`` and leak ``req_meta`` into later requests.
        msg['meta'] = {**(msg.get('meta') or {}), **req_meta}
    LOG.debug('zmq:send %s', msg)
    message = encode_(msg)
    self.socket.send_string(message)

    # receive response
    if self.poller.poll(timeout):
        res = await self.socket.recv()
    else:
        if callable(self.timeout_handler):
            self.timeout_handler()
        raise ClientTimeout(
            'Timeout waiting for server response.'
            ' This could be due to network or server issues.'
            ' Check the workflow log.')

    if msg['command'] in PB_METHOD_MAP:
        response = {'data': res}
    else:
        response = decode_(res.decode())
    LOG.debug('zmq:recv %s', response)

    try:
        return response['data']
    except KeyError:
        # No 'data' key: the server replied with an error payload.
        error = response['error']
        raise ClientError(error['message'], error.get('traceback'))
def send_request(self, command, args=None, timeout=None):
    """Send a request to the workflow over ssh.

    The ssh command, login-shell flag and remote Cylc path are read from
    the workflow contact file. ``args`` is serialised to JSON and fed to
    the remote ``cylc client`` command on stdin; its stdout is then
    deserialised from JSON.

    Use ``__call__`` to call this method.

    Args:
        command (str): The name of the endpoint to call.
        args (dict): Arguments to pass to the endpoint function.
        timeout (float): Override the default timeout (seconds); passed to
            the remote command as ``comms_timeout``.

    Raises:
        ClientError: Coverall, raised when the remote command exits with a
            non-zero return code.

    Returns:
        object: Deserialized output from function called.
    """
    # Advertise the communication method to the scheduler via the
    # environment.
    os.environ["CLIENT_COMMS_METH"] = CommsMeth.SSH.value

    timeout_opts = [f'comms_timeout={timeout}'] if timeout else []
    cmd = ["client", *timeout_opts, self.workflow, command]

    contact = load_contact_file(self.workflow)
    ssh_cmd = contact[ContactFileFields.SCHEDULER_SSH_COMMAND]
    login_shell = contact[ContactFileFields.SCHEDULER_USE_LOGIN_SHELL]
    raw_cylc_path = contact[ContactFileFields.SCHEDULER_CYLC_PATH]
    # The string 'None' in the contact file is treated as "no path".
    cylc_path = raw_cylc_path if raw_cylc_path != 'None' else None

    stdin_payload = json.dumps(args or {})
    proc = _remote_cylc_cmd(
        cmd,
        host=self.host,
        stdin_str=stdin_payload,
        ssh_cmd=ssh_cmd,
        remote_cylc_path=cylc_path,
        ssh_login_shell=login_shell,
        capture_process=True,
    )
    stdout_bytes, stderr_bytes = proc.communicate()
    out = stdout_bytes.decode()
    err = stderr_bytes.decode()

    return_code = proc.wait()
    if return_code:
        raise ClientError(err, f"return-code={return_code}")
    return json.loads(out)
def cli_cmd(*cmd):
    """Run a Cylc command-line command.

    Args:
        cmd: The command and its arguments, without the leading 'cylc'.

    Raises:
        ClientError: If the command exits with a non-zero return code;
            raised for consistency with the network client alternative.
    """
    full_cmd = ['cylc']
    full_cmd.extend(cmd)
    proc = Popen(  # nosec (command constructed internally, no untrusted input)
        full_cmd,
        stdout=PIPE,
        stderr=PIPE,
        text=True,
    )
    _stdout, stderr = proc.communicate()
    if proc.returncode != 0:
        raise ClientError(stderr)
def _timeout_handler(suite: str, host: str, port: Union[int, str]): """Handle the eventuality of a communication timeout with the suite. Args: suite (str): suite name host (str): host name port (Union[int, str]): port number Raises: ClientError: if the suite has already stopped. """ if suite is None: return # Cannot connect, perhaps suite is no longer running and is leaving # behind a contact file? try: SuiteSrvFilesManager().detect_old_contact_file(suite, (host, port)) except (AssertionError, SuiteServiceFileError): # * contact file not have matching (host, port) to suite proc # * old contact file exists and the suite process still alive return else: # the suite has stopped raise ClientError('Suite "%s" already stopped' % suite)