Пример #1
0
    def _socket_connect(self, host, port, srv_public_key_loc=None):
        """Connect socket to stub using CurveZMQ authentication.

        Creates a new socket, loads this client's key pair and the server's
        public key, installs them on the socket and connects to
        ``tcp://host:port``.

        Args:
            host (str): Host name of the server.
            port (int): Port number of the server.
            srv_public_key_loc (str): Optional full path to the server
                public key file. If None, the key is located via the suite
                service directory.

        Raises:
            ClientError: If the client private key or the server public key
                cannot be loaded.
        """
        suite_srv_dir = get_suite_srv_dir(self.suite)
        if srv_public_key_loc is None:
            # Create new KeyInfo object for the server public key
            srv_pub_key_info = KeyInfo(KeyType.PUBLIC,
                                       KeyOwner.SERVER,
                                       suite_srv_dir=suite_srv_dir)

        else:
            srv_pub_key_info = KeyInfo(KeyType.PUBLIC,
                                       KeyOwner.SERVER,
                                       full_key_path=srv_public_key_loc)

        self.host = host
        self.port = port
        self.socket = self.context.socket(self.pattern)
        self._socket_options()

        client_priv_key_info = KeyInfo(KeyType.PRIVATE,
                                       KeyOwner.CLIENT,
                                       suite_srv_dir=suite_srv_dir)
        error_msg = "Failed to find user's private key, so cannot connect."
        try:
            client_public_key, client_priv_key = zmq.auth.load_certificate(
                client_priv_key_info.full_key_path)
        except (OSError, ValueError) as exc:
            # Chain the cause so the underlying I/O / parse error is visible.
            raise ClientError(error_msg) from exc
        if client_priv_key is None:  # this can't be caught by exception
            raise ClientError(error_msg)
        self.socket.curve_publickey = client_public_key
        self.socket.curve_secretkey = client_priv_key

        # A client can only connect to the server if it knows its public key,
        # so we grab this from the location it was created on the filesystem:
        try:
            # 'load_certificate' will try to load both public & private keys
            # from a provided file but will return None, not throw an error,
            # for the latter item if not there (as for all public key files)
            # so it is OK to use; there is no method to load only the
            # public key.
            server_public_key = zmq.auth.load_certificate(
                srv_pub_key_info.full_key_path)[0]
            self.socket.curve_serverkey = server_public_key
        except (OSError, ValueError) as exc:  # ValueError w/ no public key
            raise ClientError(
                "Failed to load the suite's public key, so cannot connect."
            ) from exc

        self.socket.connect(f'tcp://{host}:{port}')
Пример #2
0
    async def async_request(self, command, args=None, timeout=None):
        """Send an asynchronous request using asyncio.

        Has the same arguments and return values as ``serial_request``.

        The outgoing message is encrypted with the suite passphrase returned
        by ``self.secret()``; responses to commands not in
        ``PB_METHOD_MAP`` are decrypted with the same passphrase.

        Raises:
            ClientError: If the passphrase cannot be read, the response
                cannot be decrypted, or the server returns an error payload.
            ClientTimeout: If no response arrives before the timeout.

        """
        if timeout:
            timeout = float(timeout)
        # Convert seconds to milliseconds for the poller, falling back to
        # the client's default timeout when none is given.
        timeout = (timeout * 1000 if timeout else None) or self.timeout
        if not args:
            args = {}

        # get secret for this request
        # assumes secret won't change during the request
        try:
            secret = self.secret()
        except cylc.flow.suite_srv_files_mgr.SuiteServiceFileError as exc:
            raise ClientError('could not read suite passphrase') from exc

        # send message
        msg = {'command': command, 'args': args}
        msg.update(self.header)
        # Lazy %-args: the message is only formatted if debug is enabled.
        LOG.debug('zmq:send %s', msg)
        message = encrypt(msg, secret)
        self.socket.send_string(message)

        # receive response
        if self.poller.poll(timeout):
            res = await self.socket.recv()
        else:
            if self.timeout_handler:
                self.timeout_handler()
            raise ClientTimeout('Timeout waiting for server response.')

        if msg['command'] in PB_METHOD_MAP:
            # Protobuf responses are returned as raw bytes, not decrypted.
            response = {'data': res}
        else:
            try:
                response = decrypt(res.decode(), secret)
            except jose.exceptions.JWTError as exc:
                raise ClientError(
                    'Could not decrypt response. Has the passphrase changed?'
                ) from exc
        LOG.debug('zmq:recv %s', response)

        try:
            return response['data']
        except KeyError:
            # No 'data' means the server sent an error payload; the
            # KeyError itself is not useful context for the caller.
            error = response['error']
            raise ClientError(
                error['message'], error.get('traceback')) from None
Пример #3
0
def get_location(suite: str, owner: str, host: str):
    """Extract host and port from a suite's contact file.

    Args:
        suite (str): suite name
        owner (str): owner of the suite
        host (str): host name; if empty/None, the host recorded in the
            contact file is used instead.
    Returns:
        Tuple[str, int, int]: the host name (normalised by
        ``get_fqdn_by_host``), the port number and the publish port number.
    Raises:
        ClientError: if the suite contact file cannot be loaded (e.g. the
            suite is not running).
    """
    try:
        contact = load_contact_file(suite, owner, host)
    except SuiteServiceFileError as exc:
        raise ClientError(f'Contact info not found for suite '
                          f'"{suite}", suite not running?') from exc

    if not host:
        host = contact[ContactFileFields.HOST]
    host = get_fqdn_by_host(host)

    port = int(contact[ContactFileFields.PORT])
    pub_port = int(contact[ContactFileFields.PUBLISH_PORT])
    return host, port, pub_port
Пример #4
0
 async def async_request(self, command, args=None, timeout=None):
     """Send asynchronous request via SSH.

     Runs the command built by ``self.prepare_command`` on ``self.host``
     through ``_remote_cylc_cmd`` and polls the resulting process without
     blocking the event loop until it exits.

     Args:
         command (str): The name of the endpoint to call.
         args (dict): Arguments to pass to the endpoint function.
         timeout (float): Override the default timeout (``self.timeout``).

     Returns:
         object: The JSON-decoded standard output of the remote command.

     Raises:
         ClientError: If the remote command exits with a non-zero code.
         ClientTimeout: If the request does not finish within ``timeout``.
     """
     timeout = timeout if timeout is not None else self.timeout
     proc = None
     try:
         # NOTE(review): ``ascyncto`` looks like a typo of ``asyncto``; kept
         # as-is because the import providing it is outside this view.
         async with ascyncto(timeout):
             cmd, ssh_cmd, login_sh, cylc_path, msg = self.prepare_command(
                 command, args, timeout)
             proc = _remote_cylc_cmd(cmd,
                                     host=self.host,
                                     stdin_str=msg,
                                     ssh_cmd=ssh_cmd,
                                     remote_cylc_path=cylc_path,
                                     ssh_login_shell=login_sh,
                                     capture_process=True)
             # Poll rather than block so other coroutines can run.
             while proc.poll() is None:
                 await asyncio.sleep(self.SLEEP_INTERVAL)
             out, err = (f.decode() for f in proc.communicate())
             return_code = proc.wait()
             if return_code:
                 raise ClientError(err, f"return-code={return_code}")
             return json.loads(out)
     except asyncio.TimeoutError:
         # We have given up on the request: don't leave the ssh process
         # running (and unreaped) in the background.
         if proc is not None and proc.poll() is None:
             proc.kill()
             proc.wait()
         raise ClientTimeout(f"Command exceeded the timeout {timeout}. "
                             "This could be due to network problems. "
                             "Check the workflow log.") from None
Пример #5
0
    def _socket_connect(self, host, port, srv_public_key_loc=None):
        """Connect socket to stub using CurveZMQ authentication.

        Creates a new socket, makes sure user keys exist, loads this client's
        key pair and the server's public key, installs them on the socket and
        connects to ``tcp://host:port``.

        Args:
            host (str): Host name of the server.
            port (int): Port number of the server.
            srv_public_key_loc (str): Optional path to the server public key
                certificate. If None it is obtained from ``get_auth_item``
                (presumably the suite's auth location — confirm).

        Raises:
            ClientError: If user keys cannot be generated, or the client
                private key or server public key cannot be loaded.
        """
        if srv_public_key_loc is None:
            srv_public_key_loc = get_auth_item(
                UserFiles.Auth.SERVER_PUBLIC_KEY_CERTIFICATE,
                self.suite,
                content=False)

        self.host = host
        self.port = port
        self.socket = self.context.socket(self.pattern)
        self._socket_options()

        # check for, & create if nonexistent, user keys in the right location
        if not ensure_user_keys_exist():
            raise ClientError("Unable to generate user authentication keys.")

        client_priv_keyfile = os.path.join(
            UserFiles.get_user_certificate_full_path(private=True),
            UserFiles.Auth.CLIENT_PRIVATE_KEY_CERTIFICATE)
        error_msg = "Failed to find user's private key, so cannot connect."
        try:
            client_public_key, client_priv_key = zmq.auth.load_certificate(
                client_priv_keyfile)
        except (OSError, ValueError) as exc:
            # Chain the cause so the underlying I/O / parse error is visible.
            raise ClientError(error_msg) from exc
        if client_priv_key is None:  # this can't be caught by exception
            raise ClientError(error_msg)
        self.socket.curve_publickey = client_public_key
        self.socket.curve_secretkey = client_priv_key

        # A client can only connect to the server if it knows its public key,
        # so we grab this from the location it was created on the filesystem:
        try:
            # 'load_certificate' will try to load both public & private keys
            # from a provided file but will return None, not throw an error,
            # for the latter item if not there (as for all public key files)
            # so it is OK to use; there is no method to load only the
            # public key.
            server_public_key = zmq.auth.load_certificate(
                srv_public_key_loc)[0]
            self.socket.curve_serverkey = server_public_key
        except (OSError, ValueError) as exc:  # ValueError w/ no public key
            raise ClientError(
                "Failed to load the suite's public key, so cannot connect."
            ) from exc

        self.socket.connect(f'tcp://{host}:{port}')
Пример #6
0
    async def async_request(self,
                            command,
                            args=None,
                            timeout=None,
                            req_meta=None):
        """Send an asynchronous request using asyncio.

        Has the same arguments and return values as ``serial_request``.

        Args:
            req_meta (dict): Optional extra metadata merged into the
                message's ``meta`` section for this request only.

        Raises:
            ClientTimeout: If no response arrives before the timeout.
            ClientError: If the server returns an error payload.

        """
        # Convert seconds to milliseconds for the poller, falling back to
        # the client's default timeout when none is given.
        timeout = (float(timeout) * 1000 if timeout else None) or self.timeout
        if not args:
            args = {}

        # Note: we are using CurveZMQ to secure the messages (see
        # self.curve_auth, self.socket.curve_...key etc.). We have set up
        # public-key cryptography on the ZMQ messaging and sockets, so
        # there is no need to encrypt messages ourselves before sending.

        # send message
        msg = {'command': command, 'args': args}
        msg.update(self.header)
        # add the request metadata
        if req_meta:
            # ``msg.update(self.header)`` is a shallow copy, so
            # ``msg['meta']`` is the same dict object held in
            # ``self.header``. Updating it in place would leak this
            # request's metadata into every later request, so build a
            # fresh dict instead.
            msg['meta'] = {**(msg.get('meta') or {}), **req_meta}
        LOG.debug('zmq:send %s', msg)
        message = encode_(msg)
        self.socket.send_string(message)

        # receive response
        if self.poller.poll(timeout):
            res = await self.socket.recv()
        else:
            if callable(self.timeout_handler):
                self.timeout_handler()
            raise ClientTimeout(
                'Timeout waiting for server response.'
                ' This could be due to network or server issues.'
                ' Check the workflow log.')

        if msg['command'] in PB_METHOD_MAP:
            # Protobuf responses are returned as raw bytes, undecoded.
            response = {'data': res}
        else:
            response = decode_(res.decode())
        LOG.debug('zmq:recv %s', response)

        try:
            return response['data']
        except KeyError:
            # No 'data' means the server sent an error payload; the
            # KeyError itself is not useful context for the caller.
            error = response['error']
            raise ClientError(
                error['message'], error.get('traceback')) from None
Пример #7
0
    def send_request(self, command, args=None, timeout=None):
        """Send a request to the scheduler over ssh.

        The ssh command, login-shell flag and remote Cylc path are read from
        the workflow contact file. ``args`` is serialised to JSON and passed
        on stdin to the remote ``cylc client`` command, whose stdout is
        deserialised and returned.

        Use ``__call__`` to call this method.

        Args:
            command (str): The name of the endpoint to call.
            args (dict): Arguments to pass to the endpoint function.
            timeout (float): Override the default timeout (seconds).
        Raises:
            ClientError: Coverall, on error from function call
        Returns:
            object: Deserialized output from function called.
        """
        # Let the scheduler side know which communication method is in use.
        os.environ["CLIENT_COMMS_METH"] = CommsMeth.SSH.value
        timeout_opts = [f'comms_timeout={timeout}'] if timeout else []
        client_cmd = ["client", *timeout_opts, self.workflow, command]

        contact = load_contact_file(self.workflow)
        remote_ssh_cmd = contact[ContactFileFields.SCHEDULER_SSH_COMMAND]
        use_login_shell = contact[ContactFileFields.SCHEDULER_USE_LOGIN_SHELL]
        remote_cylc_path = contact[ContactFileFields.SCHEDULER_CYLC_PATH]
        if remote_cylc_path == 'None':
            # a literal 'None' string in the contact file means "no path"
            remote_cylc_path = None

        stdin_payload = json.dumps(args or {})
        proc = _remote_cylc_cmd(
            client_cmd,
            host=self.host,
            stdin_str=stdin_payload,
            ssh_cmd=remote_ssh_cmd,
            remote_cylc_path=remote_cylc_path,
            ssh_login_shell=use_login_shell,
            capture_process=True)

        raw_stdout, raw_stderr = proc.communicate()
        stdout_text = raw_stdout.decode()
        stderr_text = raw_stderr.decode()
        exit_code = proc.wait()
        if exit_code:
            raise ClientError(stderr_text, f"return-code={exit_code}")
        return json.loads(stdout_text)
Пример #8
0
def cli_cmd(*cmd):
    """Issue a CLI command.

    Args:
        cmd:
            The command without the 'cylc' prefix.

    Raises:
        ClientError:
            In the event of mishap for consistency with the network
            client alternative: if the ``cylc`` executable cannot be
            launched, or if the command exits with a non-zero return code
            (the message is the command's stderr).

    """
    argv = ['cylc', *cmd]
    try:
        proc = Popen(  # nosec (constructed internally, no untrusted input)
            argv,
            stderr=PIPE,
            stdout=PIPE,
            text=True,
        )
    except OSError as exc:
        # e.g. ``cylc`` not on PATH: report it like any other failure
        # rather than leaking an OSError the caller does not expect.
        raise ClientError(str(exc)) from exc
    _out, err = proc.communicate()
    if proc.returncode != 0:
        raise ClientError(err)
Пример #9
0
    def _timeout_handler(suite: Union[str, None], host: str,
                         port: Union[int, str]) -> None:
        """Handle the eventuality of a communication timeout with the suite.

        Checks whether the suite has stopped while leaving its contact file
        behind. Returns silently if it may still be running.

        Args:
            suite (Union[str, None]): suite name; if None, nothing is
                checked.
            host (str): host name
            port (Union[int, str]): port number
        Raises:
            ClientError: if the suite has already stopped.
        """
        if suite is None:
            return
        # Cannot connect, perhaps suite is no longer running and is leaving
        # behind a contact file?
        try:
            SuiteSrvFilesManager().detect_old_contact_file(suite, (host, port))
        except (AssertionError, SuiteServiceFileError):
            # * contact file not have matching (host, port) to suite proc
            # * old contact file exists and the suite process still alive
            return
        # No exception: the suite has stopped.
        raise ClientError('Suite "%s" already stopped' % suite)