Example #1
0
def _scan1_impl(conn, timeout, my_uuid):
    """Answer suite identification requests arriving on a connection.

    Poll "conn" until MSG_QUIT is received. Every other item is unpacked as
    a (host, port) pair, and exactly one reply is sent back for it:

    * (host, port, MSG_TIMEOUT) if the anonymous identify call timed out.
    * (host, port, None) if the anonymous identify call could not connect.
    * (host, port, result) otherwise, where result is the anonymous
      identity, or the identity obtained with the passphrase if the suite
      hides its states and a passphrase was found and accepted.

    Arguments:
        conn: connection object providing poll/recv/send/close.
        timeout: timeout handed to the identify clients.
        my_uuid: client UUID handed to the identify clients.
    """
    files_mgr = SuiteSrvFilesManager()
    while True:
        if not conn.poll(SLEEP_INTERVAL):
            continue
        request = conn.recv()
        if request == MSG_QUIT:
            break
        host, port = request
        if is_remote_host(host):
            # IP reduces DNS traffic
            target = get_host_ip_by_name(host)
        else:
            target = host
        anon_client = SuiteIdClientAnon(
            None, host=target, port=port, my_uuid=my_uuid, timeout=timeout)
        try:
            result = anon_client.identify()
        except ConnectionTimeout:
            conn.send((host, port, MSG_TIMEOUT))
            continue
        except ConnectionError:
            conn.send((host, port, None))
            continue

        owner = result.get('owner')
        name = result.get('name')
        states = result.get('states', None)
        if cylc.flags.debug:
            print >> sys.stderr, '   suite:', name, owner
        if states is None:
            # States are private to this suite: look for a passphrase and,
            # if there is one, ask again as an authenticated client.
            try:
                pphrase = files_mgr.get_auth_item(
                    files_mgr.FILE_BASE_PASSPHRASE, name, owner, host,
                    content=True)
            except SuiteServiceFileError:
                pphrase = None
            if pphrase:
                auth_client = SuiteIdClient(
                    name, owner=owner, host=host, port=port,
                    my_uuid=my_uuid, timeout=timeout)
                try:
                    result = auth_client.identify()
                except ConnectionError:
                    # Private suite, wrong passphrase: keep anonymous result.
                    if cylc.flags.debug:
                        print >> sys.stderr, '    (wrong passphrase)'
                else:
                    if cylc.flags.debug:
                        print >> sys.stderr, (
                            '    (got states with passphrase)')
        conn.send((host, port, result))
    conn.close()
Example #2
0
class SuiteRuntimeServiceClient(object):
    """Client for calling the HTTP(S) API of running suites.

    The client resolves the suite host/port from the suite contact file
    (unless both are given explicitly), authenticates with HTTP Digest Auth
    and calls the server either with "requests" (if a recent enough version
    is importable) or with "urllib2". If an extra contact file is present,
    calls are routed through "cylc client" over SSH instead.
    """

    # Credentials (user, password, verify) used when no passphrase is found.
    ANON_AUTH = ('anon', NO_PASSPHRASE, False)
    # Maps client method name -> {API version: server function name}.
    COMPAT_MAP = {  # Limited pre-7.5.0 API compat mapping
        'clear_broadcast': {0: 'broadcast/clear'},
        'expire_broadcast': {0: 'broadcast/expire'},
        'get_broadcast': {0: 'broadcast/get'},
        'get_info': {0: 'info/'},
        'get_suite_state_summary': {0: 'state/get_state_summary'},
        'put_broadcast': {0: 'broadcast/put'},
        'put_command': {0: 'command/'},
        'put_ext_trigger': {0: 'ext-trigger/put'},
        'put_messages': {0: 'message/put', 1: 'put_message'},
    }
    ERROR_NO_HTTPS_SUPPORT = (
        "ERROR: server has no HTTPS support," +
        " configure your global.rc file to use HTTP : {0}\n"
    )
    METHOD = 'POST'  # Default HTTP method of "_call_server".
    METHOD_POST = 'POST'
    METHOD_GET = 'GET'

    # Defaults for "put_messages", used when the loaded contact information
    # has no corresponding setting.
    MSG_RETRY_INTVL = 5.0
    MSG_MAX_TRIES = 7
    MSG_TIMEOUT = 30.0

    def __init__(
            self, suite, owner=None, host=None, port=None, timeout=None,
            my_uuid=None, print_uuid=False, auth=None):
        """Initialise the client.

        Arguments:
            suite (str): suite name.
            owner (str): suite owner, default is the current user.
            host (str): suite host. "localhost" is replaced by the result
                of get_host(), a name with no "." is expanded by
                get_fqdn_by_host().
            port: suite port.
            timeout: request timeout, converted to float if not None.
            my_uuid: client UUID, a new uuid4() is generated if not set.
            print_uuid (bool): write the client UUID to STDERR?
            auth (tuple): (user, password, verify) to use instead of
                working it out from the suite service files.
        """
        self.suite = suite
        if not owner:
            owner = get_user()
        self.owner = owner
        self.host = host
        if self.host and self.host.split('.')[0] == 'localhost':
            self.host = get_host()
        elif self.host and '.' not in self.host:  # Not IP and no domain
            self.host = get_fqdn_by_host(self.host)
        self.port = port
        self.srv_files_mgr = SuiteSrvFilesManager()
        if timeout is not None:
            timeout = float(timeout)
        self.timeout = timeout
        self.my_uuid = my_uuid or uuid4()
        if print_uuid:
            sys.stderr.write('%s\n' % self.my_uuid)

        self.prog_name = os.path.basename(sys.argv[0])
        self.auth = auth
        self.session = None  # "requests" session, created on first use
        self.comms1 = {}  # content in primary contact file
        self.comms2 = {}  # content in extra contact file, e.g. contact via ssh

    def _compat(self, name, default=None):
        """Return server function name.

        Handle back-compat for pre-7.5.0 if relevant.

        Arguments:
            name (str): client method name, must be a key of COMPAT_MAP.
            default (str): name to return if COMPAT_MAP has no entry for
                the API version of the suite, default is "name".

        Raise ClientInfoError if the contact information cannot be loaded.
        """
        # Need to load contact info here to get API version.
        self._load_contact_info()
        if default is None:
            default = name
        return self.COMPAT_MAP[name].get(
            self.comms1.get(self.srv_files_mgr.KEY_API), default)

    def clear_broadcast(self, payload):
        """Clear broadcast runtime task settings."""
        return self._call_server(
            self._compat('clear_broadcast'), payload=payload)

    def expire_broadcast(self, **kwargs):
        """Expire broadcast runtime task settings."""
        return self._call_server(self._compat('expire_broadcast'), **kwargs)

    def get_broadcast(self, **kwargs):
        """Return broadcast settings."""
        return self._call_server(
            self._compat('get_broadcast'), method=self.METHOD_GET, **kwargs)

    def get_info(self, command, **kwargs):
        """Return suite info.

        The "command" is appended to the (possibly empty) function prefix
        for the API version of the suite.
        """
        return self._call_server(
            self._compat('get_info', default='') + command,
            method=self.METHOD_GET, **kwargs)

    def get_latest_state(self, full_mode=False):
        """Return latest state of the suite (for the GUI)."""
        self._load_contact_info()
        if self.comms1.get(self.srv_files_mgr.KEY_API) == 0:
            # Basic compat for pre-7.5.0 suites
            # Full mode only.
            # Error content/size not supported.
            # Report made-up main loop interval of 5.0 seconds.
            return {
                'cylc_version': self.get_info('get_cylc_version'),
                'full_mode': full_mode,
                'summary': self.get_suite_state_summary(),
                'ancestors': self.get_info('get_first_parent_ancestors'),
                'ancestors_pruned': self.get_info(
                    'get_first_parent_ancestors', pruned=True),
                'descendants': self.get_info('get_first_parent_descendants'),
                'err_content': '',
                'err_size': 0,
                'mean_main_loop_interval': 5.0}
        else:
            return self._call_server(
                'get_latest_state',
                method=self.METHOD_GET, full_mode=full_mode)

    def get_suite_state_summary(self):
        """Return the global, task, and family summary data structures."""
        return utf8_enforce(self._call_server(
            self._compat('get_suite_state_summary'), method=self.METHOD_GET))

    def identify(self):
        """Return suite identity."""
        # Note on compat: Suites on 7.6.0 or above can just call "identify",
        # but has compat for "id/identity".
        return self._call_server('id/identify', method=self.METHOD_GET)

    def put_broadcast(self, payload):
        """Put/set broadcast runtime task settings."""
        return self._call_server(
            self._compat('put_broadcast'), payload=payload)

    def put_command(self, command, **kwargs):
        """Invoke suite command.

        The "command" is appended to the (possibly empty) function prefix
        for the API version of the suite.
        """
        return self._call_server(
            self._compat('put_command', default='') + command, **kwargs)

    def put_ext_trigger(self, event_message, event_id):
        """Put external trigger."""
        return self._call_server(
            self._compat('put_ext_trigger'),
            event_message=event_message, event_id=event_id)

    def put_messages(self, payload):
        """Send task messages to suite server program.

        Try up to "max tries" times, sleeping for the retry interval between
        failed tries. If the client has no timeout, MSG_TIMEOUT is used for
        the duration of each try. After a failed try, the cached contact
        information, host, port and auth are reset so that they are loaded
        again on the next try.

        Arguments:
            payload (dict):
                task_job (str): Task job as "CYCLE/TASK_NAME/SUBMIT_NUM".
                event_time (str): Event time as string.
                messages (list): List in the form [[severity, message], ...].

        Return the server response: for the current API, the result of the
        single "put_messages" call; for the compat APIs, a list with the
        result of one call per message. Return None if all tries failed.

        Raise ClientInfoError (no retry) if contact info cannot be loaded.
        """
        # NOTE(review): these settings are read from self.comms1 before the
        # contact file is necessarily loaded (it is loaded by "_compat"
        # below), so the class defaults may be what is used here - confirm
        # whether that is intended.
        retry_intvl = float(self.comms1.get(
            self.srv_files_mgr.KEY_TASK_MSG_RETRY_INTVL,
            self.MSG_RETRY_INTVL))
        max_tries = int(self.comms1.get(
            self.srv_files_mgr.KEY_TASK_MSG_MAX_TRIES,
            self.MSG_MAX_TRIES))
        for i in range(1, max_tries + 1):  # 1..max_tries inclusive
            orig_timeout = self.timeout
            if self.timeout is None:
                self.timeout = self.MSG_TIMEOUT
            try:
                func_name = self._compat('put_messages')
                if func_name == 'put_messages':
                    results = self._call_server(func_name, payload=payload)
                elif func_name == 'put_message':  # API 1, 7.5.0 compat
                    # One call per message, collect the individual results.
                    results = []
                    cycle, name = payload['task_job'].split('/')[0:2]
                    for severity, message in payload['messages']:
                        results.append(self._call_server(
                            func_name, task_id='%s.%s' % (name, cycle),
                            severity=severity, message=message))
                else:  # API 0, pre-7.5.0 compat, priority instead of severity
                    results = []
                    cycle, name = payload['task_job'].split('/')[0:2]
                    for severity, message in payload['messages']:
                        results.append(self._call_server(
                            func_name, task_id='%s.%s' % (name, cycle),
                            priority=severity, message=message))
            except ClientInfoError:
                # Contact info file not found, suite probably not running.
                # Don't bother with retry, suite restart will poll any way.
                raise
            except ClientError as exc:
                now = get_current_time_string()
                sys.stderr.write(
                    "%s WARNING - Message send failed, try %s of %s: %s\n" % (
                        now, i, max_tries, exc))
                if i < max_tries:
                    sys.stderr.write(
                        "   retry in %s seconds, timeout is %s\n" % (
                            retry_intvl, self.timeout))
                    sleep(retry_intvl)
                    # Reset in case contact info or passphrase change
                    self.comms1 = {}
                    self.host = None
                    self.port = None
                    self.auth = None
            else:
                if i > 1:
                    # Continue to write to STDERR, so users can easily see that
                    # it has recovered from previous failures.
                    sys.stderr.write(
                        "%s INFO - Send message: try %s of %s succeeded\n" % (
                            get_current_time_string(), i, max_tries))
                return results
            finally:
                # Restore the timeout, whatever the outcome of this try.
                self.timeout = orig_timeout

    def reset(self):
        """Compat method, does nothing."""
        pass

    def signout(self):
        """Tell server to forget this client."""
        return self._call_server('signout')

    def _call_server(self, function, method=METHOD, payload=None, **kwargs):
        """Build server URL + call it.

        If extra contact information (self.comms2) is set, the call is made
        via "_call_server_via_comms2" instead, and "method" is not used.

        Arguments:
            function (str): server function name, i.e. the URL path.
            method (str): HTTP method.
            payload: data to send as the JSON body of the request.
            **kwargs: arguments to encode in the URL query string.
        """
        if self.comms2:
            return self._call_server_via_comms2(function, payload, **kwargs)
        url = self._call_server_get_url(function, **kwargs)
        # Remove proxy settings from environment for now
        environ = {}
        for key in ("http_proxy", "https_proxy"):
            val = os.environ.pop(key, None)
            # Remember any value that was set, including an empty one, so
            # that the environment is restored exactly afterwards.
            if val is not None:
                environ[key] = val
        try:
            return self.call_server_impl(url, method, payload)
        finally:
            os.environ.update(environ)

    def _call_server_get_url(self, function, **kwargs):
        """Build request URL.

        The scheme is taken from the primary contact information, or from
        the global configuration if not set there. Any "kwargs" are encoded
        as the query string.
        """
        scheme = self.comms1.get(self.srv_files_mgr.KEY_COMMS_PROTOCOL)
        if scheme is None:
            # Use standard setting from global configuration
            scheme = glbl_cfg().get(['communication', 'method'])
        url = '%s://%s:%s/%s' % (
            scheme, self.host, self.port, function)
        # If there are any parameters left in the dict after popping,
        # append them to the url.
        if kwargs:
            import urllib
            params = urllib.urlencode(kwargs, doseq=True)
            url += "?" + params
        return url

    def call_server_impl(self, url, method, payload):
        """Determine whether to use requests or urllib2 to call suite API.

        "requests" is used if it can be imported and its version is at
        least 2.4.2, "urllib2" otherwise.

        On a connection error for a named suite, check for a contact file
        left behind by a suite that is no longer running. If that check
        passes, raise a ClientConnectError that reports the suite as
        stopped, otherwise re-raise the original error.
        """
        impl = self._call_server_impl_urllib2
        try:
            import requests
        except ImportError:
            pass
        else:
            if [int(_) for _ in requests.__version__.split(".")] >= [2, 4, 2]:
                impl = self._call_server_impl_requests
        try:
            return impl(url, method, payload)
        except ClientConnectError as exc:
            if self.suite is None:
                raise
            # Cannot connect, perhaps suite is no longer running and is leaving
            # behind a contact file?
            try:
                self.srv_files_mgr.detect_old_contact_file(
                    self.suite, (self.host, self.port))
            except (AssertionError, SuiteServiceFileError):
                raise exc
            else:
                # self.srv_files_mgr.detect_old_contact_file should delete left
                # behind contact file if the old suite process no longer
                # exists. Should be safe to report that the suite has stopped.
                raise ClientConnectError(exc.args[0], exc.STOPPED % self.suite)

    def _call_server_impl_requests(self, url, method, payload):
        """Call server with "requests" library.

        Return the decoded JSON of the response, or the response text if it
        is not valid JSON.

        Raise:
            CylcError: HTTPS requested, but the server appears to use HTTP.
            ClientTimeout: request timed out.
            ClientConnectError: any other failure to make the request.
            ClientDeniedError: server responded with status 401.
            ClientConnectedError: server responded with another HTTP error.
        """
        import requests
        from requests.packages.urllib3.exceptions import InsecureRequestWarning
        warnings.simplefilter("ignore", InsecureRequestWarning)
        if self.session is None:
            self.session = requests.Session()

        if method == self.METHOD_POST:
            session_method = self.session.post
        else:
            session_method = self.session.get
        scheme = url.split(':', 1)[0]  # Can use urlparse?
        username, password, verify = self._get_auth(scheme)
        try:
            ret = session_method(
                url,
                json=payload,
                verify=verify,
                proxies={},
                headers=self._get_headers(),
                auth=requests.auth.HTTPDigestAuth(username, password),
                timeout=self.timeout
            )
        except requests.exceptions.SSLError as exc:
            if "unknown protocol" in str(exc) and url.startswith("https:"):
                # Server is using http rather than https, for some reason.
                sys.stderr.write(self.ERROR_NO_HTTPS_SUPPORT.format(exc))
                raise CylcError(
                    "Cannot issue commands over unsecured http.")
            if cylc.flags.debug:
                traceback.print_exc()
            raise ClientConnectError(url, exc)
        except requests.exceptions.Timeout as exc:
            if cylc.flags.debug:
                traceback.print_exc()
            raise ClientTimeout(url, exc)
        except requests.exceptions.RequestException as exc:
            if cylc.flags.debug:
                traceback.print_exc()
            raise ClientConnectError(url, exc)
        if ret.status_code == 401:
            access_desc = 'private'
            if self.auth == self.ANON_AUTH:
                access_desc = 'public'
            raise ClientDeniedError(url, self.prog_name, access_desc)
        if ret.status_code >= 400:
            # Report the server side error before raising below.
            exception_text = get_exception_from_html(ret.text)
            if exception_text:
                sys.stderr.write(exception_text)
            else:
                sys.stderr.write(ret.text)
        try:
            ret.raise_for_status()
        except requests.exceptions.HTTPError as exc:
            if cylc.flags.debug:
                traceback.print_exc()
            raise ClientConnectedError(url, exc)
        # The passphrase worked, cache it.
        if self.auth and self.auth[1] != NO_PASSPHRASE:
            self.srv_files_mgr.cache_passphrase(
                self.suite, self.owner, self.host, self.auth[1])
        try:
            return ret.json()
        except ValueError:
            return ret.text

    def _call_server_impl_urllib2(self, url, method, payload):
        """Call server with "urllib2" library.

        Return the decoded JSON of the response, or the response text if it
        is not valid JSON.

        Raise:
            CylcError: HTTPS requested, but the server appears to use HTTP.
            ClientTimeout: URL error with "timed out" in its text.
            ClientConnectError: any other URL error.
            ClientError: any other error on opening the URL.
            ClientDeniedError: response code is 401.
            ClientConnectedError: response code is another HTTP error.
        """
        import json
        import urllib2
        import ssl
        # Switch off certificate verification where the "ssl" module has
        # the means to do so. Note: this changes a process wide default.
        unverified_context = getattr(ssl, '_create_unverified_context', None)
        if unverified_context is not None:
            ssl._create_default_https_context = unverified_context

        scheme = url.split(':', 1)[0]  # Can use urlparse?
        username, password = self._get_auth(scheme)[0:2]
        auth_manager = urllib2.HTTPPasswordMgrWithDefaultRealm()
        auth_manager.add_password(None, url, username, password)
        auth = urllib2.HTTPDigestAuthHandler(auth_manager)
        opener = urllib2.build_opener(auth, urllib2.HTTPSHandler())
        headers_list = self._get_headers().items()
        if payload:
            payload = json.dumps(payload)
            headers_list.append(('Accept', 'application/json'))
            json_headers = {'Content-Type': 'application/json',
                            'Content-Length': len(payload)}
        else:
            payload = None
            json_headers = {'Content-Length': 0}
        opener.addheaders = headers_list
        req = urllib2.Request(url, payload, json_headers)

        # This is an unpleasant monkey patch, but there isn't an
        # alternative. urllib2 uses POST if there is a data payload
        # but that is not the correct criterion.
        # The difference is basically that POST changes
        # server state and GET doesn't.
        req.get_method = lambda: method
        try:
            response = opener.open(req, timeout=self.timeout)
        except urllib2.URLError as exc:
            if "unknown protocol" in str(exc) and url.startswith("https:"):
                # Server is using http rather than https, for some reason.
                sys.stderr.write(self.ERROR_NO_HTTPS_SUPPORT.format(exc))
                raise CylcError(
                    "Cannot issue commands over unsecured http.")
            if cylc.flags.debug:
                traceback.print_exc()
            if "timed out" in str(exc):
                raise ClientTimeout(url, exc)
            else:
                raise ClientConnectError(url, exc)
        except Exception as exc:
            if cylc.flags.debug:
                traceback.print_exc()
            raise ClientError(url, exc)

        if response.getcode() == 401:
            access_desc = 'private'
            if self.auth == self.ANON_AUTH:
                access_desc = 'public'
            raise ClientDeniedError(url, self.prog_name, access_desc)
        response_text = response.read()
        if response.getcode() >= 400:
            exception_text = get_exception_from_html(response_text)
            if exception_text:
                sys.stderr.write(exception_text)
            else:
                sys.stderr.write(response_text)
            raise ClientConnectedError(
                url,
                "%s HTTP return code" % response.getcode())
        # The passphrase worked, cache it.
        if self.auth and self.auth[1] != NO_PASSPHRASE:
            self.srv_files_mgr.cache_passphrase(
                self.suite, self.owner, self.host, self.auth[1])

        try:
            return json.loads(response_text)
        except ValueError:
            return response_text

    def _call_server_via_comms2(self, function, payload, **kwargs):
        """Call server via "cylc client --use-ssh".

        Call "cylc client --use-ssh" using `subprocess.Popen`. Payload and
        arguments of the API method are serialized as JSON and are written to a
        temporary file, which is then used as the STDIN of the "cylc client"
        command. The external call here should be even safer than a direct
        HTTP(S) call, as it can be blocked by SSH before it even gets a chance
        to make the subsequent HTTP(S) call.

        Arguments:
            function (str): name of API method, argument 1 of "cylc client".
            payload (str): extra data or information for the API method.
            **kwargs (dict): arguments for the API method.

        Return the decoded JSON output of the command.

        Raise ClientError if the command has a non-zero return code.
        """
        import json
        from cylc.remote import remote_cylc_cmd
        command = ["client", function, self.suite]
        if payload:
            kwargs["payload"] = payload
        # With stdin=None, `remote_cylc_cmd` will:
        # * Set stdin to open(os.devnull)
        # * Add `-n` to the SSH command
        stdin = None
        try:
            if kwargs:
                from tempfile import TemporaryFile
                stdin = TemporaryFile()
                json.dump(kwargs, stdin)
                stdin.seek(0)
            proc = remote_cylc_cmd(
                command, self.owner, self.host, capture=True,
                ssh_login_shell=(self.comms1.get(
                    self.srv_files_mgr.KEY_SSH_USE_LOGIN_SHELL
                ) in ['True', 'true']),
                ssh_cylc=(r'%s/bin/cylc' % self.comms1.get(
                    self.srv_files_mgr.KEY_DIR_ON_SUITE_HOST)
                ),
                stdin=stdin,
            )
            out = proc.communicate()[0]
            return_code = proc.wait()
        finally:
            # The command is done with (or never got) its input: close the
            # temporary file now, instead of leaking it until it is garbage
            # collected.
            if stdin is not None:
                stdin.close()
        if return_code:
            from pipes import quote
            command_str = " ".join(quote(item) for item in command)
            raise ClientError(command_str, "return-code=%d" % return_code)
        return json.loads(out)

    def _get_auth(self, protocol):
        """Return a user/password Digest Auth.

        The value is a (user, password, verify) tuple. It is worked out on
        the first call and kept in self.auth: ANON_AUTH is used unless a
        passphrase (and, for "https", an SSL certificate) can be obtained
        from the suite service files.
        """
        if self.auth is None:
            self.auth = self.ANON_AUTH
            try:
                pphrase = self.srv_files_mgr.get_auth_item(
                    self.srv_files_mgr.FILE_BASE_PASSPHRASE,
                    self.suite, self.owner, self.host, content=True)
                if protocol == 'https':
                    verify = self.srv_files_mgr.get_auth_item(
                        self.srv_files_mgr.FILE_BASE_SSL_CERT,
                        self.suite, self.owner, self.host)
                else:
                    verify = False
            except SuiteServiceFileError:
                # Stay anonymous.
                pass
            else:
                if pphrase and pphrase != NO_PASSPHRASE:
                    self.auth = ('cylc', pphrase, verify)
        return self.auth

    def _get_headers(self):
        """Return HTTP headers identifying the client."""
        user_agent_string = (
            "cylc/%s prog_name/%s uuid/%s" % (
                CYLC_VERSION, self.prog_name, self.my_uuid
            )
        )
        auth_info = "%s@%s" % (get_user(), get_host())
        return {"User-Agent": user_agent_string,
                "From": auth_info}

    def _load_contact_info(self):
        """Obtain suite owner, host, port info.

        Determine host and port using content in port file, unless already
        specified.

        Raise:
            ClientInfoError: contact file cannot be loaded, or it has no
                usable port.
            ClientInfoUUIDError: suite UUID in the environment does not
                match the one in the contact file of the same suite.
        """
        if self.host and self.port:
            return
        if self.port:
            # In case the contact file is corrupted, user can specify the port.
            self.host = get_host()
            return
        try:
            # Always trust the values in the contact file otherwise.
            self.comms1 = self.srv_files_mgr.load_contact_file(
                self.suite, self.owner, self.host)
            # Port inside "try" block, as it needs a type conversion
            self.port = int(self.comms1.get(self.srv_files_mgr.KEY_PORT))
        except (IOError, TypeError, ValueError, SuiteServiceFileError):
            # TypeError: no port in the contact file, i.e. int(None).
            raise ClientInfoError(self.suite)
        else:
            # Check mismatch suite UUID
            env_suite = os.getenv(self.srv_files_mgr.KEY_NAME)
            env_uuid = os.getenv(self.srv_files_mgr.KEY_UUID)
            if (self.suite and env_suite and env_suite == self.suite and
                    env_uuid and
                    env_uuid != self.comms1.get(self.srv_files_mgr.KEY_UUID)):
                raise ClientInfoUUIDError(
                    env_uuid, self.comms1[self.srv_files_mgr.KEY_UUID])
            # All good
            self.host = self.comms1.get(self.srv_files_mgr.KEY_HOST)
            self.owner = self.comms1.get(self.srv_files_mgr.KEY_OWNER)
            if self.srv_files_mgr.KEY_API not in self.comms1:
                self.comms1[self.srv_files_mgr.KEY_API] = 0  # <=7.5.0 compat
        # Indirect comms settings
        self.comms2.clear()
        try:
            self.comms2.update(self.srv_files_mgr.load_contact_file(
                self.suite, self.owner, self.host,
                SuiteSrvFilesManager.FILE_BASE_CONTACT2))
        except SuiteServiceFileError:
            # No extra contact file: direct comms only.
            pass
Example #3
0
def _scan1_impl(conn, timeout, my_uuid):
    """Identify the suite behind each host:port received on a connection.

    Keep polling "conn" until MSG_QUIT arrives. Each other item is a
    (host, port) pair, for which a (host, port, outcome) reply is sent:

    * outcome is MSG_TIMEOUT if the anonymous identify call timed out.
    * outcome is None if the anonymous identify call failed to connect, or
      if the suite is still initialising.
    * outcome is the suite identity otherwise. If the anonymous identity
      has no states, and a passphrase is available, the identity returned
      to an authenticated client is used instead, unless that second call
      fails.

    Arguments:
        conn: connection object providing poll/recv/send/close.
        timeout: timeout handed to the identify clients.
        my_uuid: client UUID handed to the identify clients.
    """
    files_mgr = SuiteSrvFilesManager()
    while True:
        if not conn.poll(SLEEP_INTERVAL):
            continue
        item = conn.recv()
        if item == MSG_QUIT:
            break
        host, port = item
        # IP reduces DNS traffic
        address = get_host_ip_by_name(host) if is_remote_host(host) else host
        anon_client = SuiteIdClientAnon(
            None, host=address, port=port, my_uuid=my_uuid, timeout=timeout)
        try:
            outcome = anon_client.identify()
        except ConnectionTimeout:
            outcome = MSG_TIMEOUT
        except (ConnectionError, SuiteStillInitialisingError):
            outcome = None
        else:
            owner = outcome.get('owner')
            name = outcome.get('name')
            states = outcome.get('states', None)
            if cylc.flags.debug:
                print >> sys.stderr, '   suite:', name, owner
            if states is None:
                # The suite does not publish its states.
                # Have another go as an authenticated client, if possible.
                try:
                    pphrase = files_mgr.get_auth_item(
                        files_mgr.FILE_BASE_PASSPHRASE, name, owner, host,
                        content=True)
                except SuiteServiceFileError:
                    pphrase = None
                if pphrase:
                    auth_client = SuiteIdClient(
                        name, owner=owner, host=host, port=port,
                        my_uuid=my_uuid, timeout=timeout)
                    try:
                        outcome = auth_client.identify()
                    except SuiteStillInitialisingError:
                        if cylc.flags.debug:
                            print >> sys.stderr, (
                                '    (connected with passphrase,' +
                                ' suite initialising)')
                    except ConnectionError:
                        # Private suite, wrong passphrase.
                        if cylc.flags.debug:
                            print >> sys.stderr, '    (wrong passphrase)'
                    else:
                        if cylc.flags.debug:
                            print >> sys.stderr, (
                                '    (got states with passphrase)')
        conn.send((host, port, outcome))
    conn.close()
Example #4
0
class HTTPServer(object):
    """HTTP(S) server by cherrypy, for serving suite runtime API.

    The server is started on construction, on the first usable port in the
    (shuffled) range of ports allowed by the global configuration.
    """

    API = 1  # API version, as returned by "apiversion".
    # Log template, filled with (user, host, program name, client UUID).
    LOG_CONNECT_DENIED_TMPL = "[client-connect] DENIED %s@%s:%s %s"
    # Script names as stored by cherrypy for the "/" and "/id" mount points
    # used by "connect" (cherrypy strips the trailing "/" on mounting).
    _MOUNTED_SCRIPT_NAMES = ('', '/id')

    def __init__(self, suite):
        """Configure the server for "suite" and start it.

        Raise:
            CylcError: communication method is not "http", and the suite
                SSL certificate or private key cannot be obtained.
            Exception: none of the allowed ports can be used.
        """
        # Suite only needed for back-compat with old clients (see below):
        self.suite = suite
        self.engine = None
        self.port = None

        # Figure out the ports we are allowed to use.
        base_port = GLOBAL_CFG.get(['communication', 'base port'])
        max_ports = GLOBAL_CFG.get(
            ['communication', 'maximum number of ports'])
        self.ok_ports = range(int(base_port), int(base_port) + int(max_ports))
        random.shuffle(self.ok_ports)

        comms_options = GLOBAL_CFG.get(['communication', 'options'])

        # HTTP Digest Auth uses MD5 - pretty secure in this use case.
        # Extending it with extra algorithms is allowed, but won't be
        # supported by most browsers. requests and urllib2 are OK though.
        self.hash_algorithm = "MD5"
        if "SHA1" in comms_options:
            # Note 'SHA' rather than 'SHA1'.
            self.hash_algorithm = "SHA"

        self.srv_files_mgr = SuiteSrvFilesManager()
        self.comms_method = GLOBAL_CFG.get(['communication', 'method'])
        # Two users: "cylc" with the suite passphrase, and "anon".
        self.get_ha1 = cherrypy.lib.auth_digest.get_ha1_dict_plain(
            {
                'cylc': self.srv_files_mgr.get_auth_item(
                    self.srv_files_mgr.FILE_BASE_PASSPHRASE,
                    suite, content=True),
                'anon': NO_PASSPHRASE
            },
            algorithm=self.hash_algorithm)
        if self.comms_method == 'http':
            self.cert = None
            self.pkey = None
        else:  # if self.comms_method in [None, 'https']:
            try:
                self.cert = self.srv_files_mgr.get_auth_item(
                    self.srv_files_mgr.FILE_BASE_SSL_CERT, suite)
                self.pkey = self.srv_files_mgr.get_auth_item(
                    self.srv_files_mgr.FILE_BASE_SSL_PEM, suite)
            except SuiteServiceFileError:
                ERR.error("no HTTPS/OpenSSL support. Aborting...")
                raise CylcError("No HTTPS support. "
                                "Configure user's global.rc to use HTTP.")
        self.start()

    @cherrypy.expose
    def apiversion(self):
        """Return API version."""
        return str(self.API)

    @staticmethod
    def connect(schd):
        """Mount suite scheduler object to the web server."""
        cherrypy.tree.mount(SuiteRuntimeService(schd), '/')
        # For back-compat with "scan"
        cherrypy.tree.mount(SuiteRuntimeService(schd), '/id')

    @staticmethod
    def disconnect(schd):
        """Disconnect obj from the web server.

        Remove the applications mounted by "connect". Do nothing for a
        mount point with no application.
        """
        # "schd" is kept for interface compatibility. The applications are
        # mounted at fixed points, not at "/<owner>/<suite>".
        for script_name in HTTPServer._MOUNTED_SCRIPT_NAMES:
            cherrypy.tree.apps.pop(script_name, None)

    def get_port(self):
        """Return the web server port."""
        return self.port

    def shutdown(self):
        """Shutdown the web server.

        Do nothing if the server engine was never set up.
        """
        # "engine" is set to None on construction, so test its value, not
        # just whether the attribute exists.
        engine = getattr(self, "engine", None)
        if engine is not None:
            engine.exit()
            engine.block()

    def start(self):
        """Start quick web service.

        Try each allowed port in turn, until the cherrypy engine starts.
        Set self.port to the port in use.

        Raise Exception if the engine cannot be started on any port.
        """
        # cherrypy.config["tools.encode.on"] = True
        # cherrypy.config["tools.encode.encoding"] = "utf-8"
        cherrypy.config["server.socket_host"] = get_host()
        cherrypy.config["engine.autoreload.on"] = False

        if self.comms_method == "https":
            # Setup SSL etc. Otherwise fail and exit.
            # Require connection method to be the same e.g HTTP/HTTPS matching.
            cherrypy.config['server.ssl_module'] = 'pyopenSSL'
            cherrypy.config['server.ssl_certificate'] = self.cert
            cherrypy.config['server.ssl_private_key'] = self.pkey

        cherrypy.config['log.screen'] = None
        # Random key for the digest authentication tool.
        key = binascii.hexlify(os.urandom(16))
        cherrypy.config.update({
            'tools.auth_digest.on': True,
            'tools.auth_digest.realm': self.suite,
            'tools.auth_digest.get_ha1': self.get_ha1,
            'tools.auth_digest.key': key,
            'tools.auth_digest.algorithm': self.hash_algorithm
        })
        # Log denied connections at the end of each request.
        cherrypy.tools.connect_log = cherrypy.Tool(
            'on_end_resource', self._report_connection_if_denied)
        cherrypy.config['tools.connect_log.on'] = True
        self.engine = cherrypy.engine
        for port in self.ok_ports:
            cherrypy.config["server.socket_port"] = port
            try:
                cherrypy.engine.start()
                cherrypy.engine.wait(cherrypy.engine.states.STARTED)
            except cherrypy.process.wspbus.ChannelFailures:
                if cylc.flags.debug:
                    traceback.print_exc()
                # We need to reinitialise the httpserver for each port attempt.
                cherrypy.server.httpserver = None
            else:
                if cherrypy.engine.state == cherrypy.engine.states.STARTED:
                    self.port = port
                    return
        raise Exception("No available ports")

    @staticmethod
    def _get_client_connection_denied():
        """Return whether a connection was denied.

        A connection counts as denied if the request has an "Authorization"
        header and the response status is 401 or 403.
        """
        if "Authorization" not in cherrypy.request.headers:
            # Probably just the initial HTTPS handshake.
            return False
        status = cherrypy.response.status
        if isinstance(status, basestring):
            # Status as a string: the code is the first word.
            return cherrypy.response.status.split()[0] in ["401", "403"]
        return cherrypy.response.status in [401, 403]

    def _report_connection_if_denied(self):
        """Log an (un?)successful connection attempt."""
        prog_name, user, host, uuid = _get_client_info()[1:]
        connection_denied = self._get_client_connection_denied()
        if connection_denied:
            LOG.warning(self.__class__.LOG_CONNECT_DENIED_TMPL % (
                user, host, prog_name, uuid))
Example #5
0
class SuiteRuntimeServiceClient(object):
    """Client for calling the HTTP(S) API of running suites."""

    ANON_AUTH = ('anon', NO_PASSPHRASE, False)
    COMPAT_MAP = {  # Limited pre-7.5.0 API compat mapping
        'clear_broadcast': {0: 'broadcast/clear'},
        'expire_broadcast': {0: 'broadcast/expire'},
        'get_broadcast': {0: 'broadcast/get'},
        'get_info': {0: 'info/'},
        'get_suite_state_summary': {0: 'state/get_state_summary'},
        'get_tasks_by_state': {0: 'state/get_tasks_by_state'},
        'put_broadcast': {0: 'broadcast/put'},
        'put_command': {0: 'command/'},
        'put_ext_trigger': {0: 'ext-trigger/put'},
        'put_message': {0: 'message/put'},
    }
    ERROR_NO_HTTPS_SUPPORT = (
        "ERROR: server has no HTTPS support," +
        " configure your global.rc file to use HTTP : {0}\n")
    METHOD = 'POST'
    METHOD_POST = 'POST'
    METHOD_GET = 'GET'

    def __init__(self,
                 suite,
                 owner=None,
                 host=None,
                 port=None,
                 timeout=None,
                 my_uuid=None,
                 print_uuid=False,
                 comms_protocol=None,
                 auth=None):
        """Store the details needed to talk to a suite.

        The owner defaults to the current user. A "localhost" host is
        replaced by the local host name and a host without any dot is
        expanded with get_fqdn_by_host. A timeout, when given, is stored as
        a float. A fresh UUID is generated unless my_uuid is supplied, and
        it is written to stderr when print_uuid is true.
        """
        self.suite = suite
        self.owner = owner or get_user()
        if host:
            if host.split('.')[0] == 'localhost':
                host = get_host()
            elif '.' not in host:  # Not IP and no domain
                host = get_fqdn_by_host(host)
        self.host = host
        self.port = port
        self.srv_files_mgr = SuiteSrvFilesManager()
        self.comms_protocol = comms_protocol
        self.timeout = None if timeout is None else float(timeout)
        self.my_uuid = my_uuid or uuid4()
        if print_uuid:
            sys.stderr.write('%s\n' % self.my_uuid)

        self.prog_name = os.path.basename(sys.argv[0])
        self.auth = auth
        # Filled in later: HTTP session and server API version.
        self.session = None
        self.api = None

    def _compat(self, name, default=None):
        """Translate a client call name into the server function name.

        Contact info is loaded first because that is where the API version
        comes from. The name is then looked up in COMPAT_MAP for that API
        version (pre-7.5.0 back-compat); if there is no entry, "default" is
        returned, or "name" itself when "default" is None.
        """
        self._load_contact_info()
        fallback = name if default is None else default
        names_by_api = self.COMPAT_MAP[name]
        return names_by_api.get(self.api, fallback)

    def clear_broadcast(self, **kwargs):
        """Clear broadcast runtime task settings; kwargs form the payload."""
        server_func = self._compat('clear_broadcast')
        return self._call_server(server_func, payload=kwargs)

    def expire_broadcast(self, **kwargs):
        """Expire broadcast runtime task settings; kwargs are passed on."""
        server_func = self._compat('expire_broadcast')
        return self._call_server(server_func, **kwargs)

    def get_broadcast(self, **kwargs):
        """Fetch broadcast settings with a GET request."""
        server_func = self._compat('get_broadcast')
        return self._call_server(
            server_func, method=self.METHOD_GET, **kwargs)

    def get_info(self, command, **kwargs):
        """Fetch suite info for "command" with a GET request.

        The command is prefixed with the back-compat path for old servers,
        or with nothing otherwise.
        """
        prefix = self._compat('get_info', default='')
        return self._call_server(
            prefix + command, method=self.METHOD_GET, **kwargs)

    def get_latest_state(self, full_mode):
        """Fetch the latest state of the suite (for the GUI).

        For API version 0 servers the result is assembled from several
        older calls; otherwise a single "get_latest_state" GET is issued.
        """
        self._load_contact_info()
        if self.api != 0:
            return self._call_server('get_latest_state',
                                     method=self.METHOD_GET,
                                     full_mode=full_mode)
        # Basic compat for pre-7.5.0 suites:
        # * Full mode only.
        # * Error content/size not supported.
        # * Report made-up main loop interval of 5.0 seconds.
        state = {'cylc_version': self.get_info('get_cylc_version')}
        state['full_mode'] = full_mode
        state['summary'] = self.get_suite_state_summary()
        state['ancestors'] = self.get_info('get_first_parent_ancestors')
        state['ancestors_pruned'] = self.get_info(
            'get_first_parent_ancestors', pruned=True)
        state['descendants'] = self.get_info('get_first_parent_descendants')
        state['err_content'] = ''
        state['err_size'] = 0
        state['mean_main_loop_interval'] = 5.0
        return state

    def get_suite_state_summary(self):
        """Fetch the global, task, and family summary data structures.

        The server reply is passed through utf8_enforce before returning.
        """
        server_func = self._compat('get_suite_state_summary')
        summary = self._call_server(server_func, method=self.METHOD_GET)
        return utf8_enforce(summary)

    def get_tasks_by_state(self):
        """Fetch a dict containing lists of tasks keyed by state.

        Result in the form:
        {state: [(most_recent_time_string, task_name, point_string), ...]}
        """
        server_func = self._compat('get_tasks_by_state')
        return self._call_server(server_func, method=self.METHOD_GET)

    def identify(self):
        """Fetch the suite identity with a GET request."""
        # Compat note: the older "id/identify" path is always used here;
        # suites on 7.6.0 or above could also be asked via plain "identify".
        identity = self._call_server('id/identify', method=self.METHOD_GET)
        return identity

    def put_broadcast(self, **kwargs):
        """Set broadcast runtime task settings; kwargs form the payload."""
        server_func = self._compat('put_broadcast')
        return self._call_server(server_func, payload=kwargs)

    def put_command(self, command, **kwargs):
        """Invoke the suite command "command", passing kwargs along.

        The command is prefixed with the back-compat path for old servers,
        or with nothing otherwise.
        """
        prefix = self._compat('put_command', default='')
        return self._call_server(prefix + command, **kwargs)

    def put_ext_trigger(self, event_message, event_id):
        """Send an external trigger (message and ID) to the suite."""
        server_func = self._compat('put_ext_trigger')
        return self._call_server(
            server_func, event_message=event_message, event_id=event_id)

    def put_message(self, task_id, severity, message):
        """Send a task message to the suite.

        The severity is sent under the name "severity", except to servers
        on the pre-7.5.0 API, which take it under the name "priority".
        """
        func_name = self._compat('put_message')
        if func_name == 'put_message':
            severity_key = 'severity'
        else:  # pre-7.5.0 API compat
            severity_key = 'priority'
        params = {'task_id': task_id}
        params[severity_key] = severity
        params['message'] = message
        return self._call_server(func_name, **params)

    def reset(self):
        """Do nothing; kept only for compatibility."""
        return None

    def signout(self):
        """Ask the server to forget this client."""
        reply = self._call_server('signout')
        return reply

    def _call_server(self, function, method=METHOD, payload=None, **kwargs):
        """Build the request URL for "function" and call the server.

        Arguments:
            function: server function path, appended to the base URL.
            method: HTTP method name; defaults to the class METHOD.
            payload: optional data handed on to call_server_impl.
            **kwargs: turned into URL query parameters.

        Returns whatever call_server_impl returns.

        The "http_proxy" and "https_proxy" environment variables are removed
        for the duration of the call and put back afterwards.
        """
        url = self._call_server_get_url(function, **kwargs)
        # Remove proxy settings from environment for now
        saved_environ = {}
        for key in ("http_proxy", "https_proxy"):
            val = os.environ.pop(key, None)
            # Test against None, not truthiness: a variable that was set to
            # an empty string must be restored as well.
            if val is not None:
                saved_environ[key] = val
        try:
            return self.call_server_impl(url, method, payload)
        finally:
            os.environ.update(saved_environ)

    def _call_server_get_url(self, function, **kwargs):
        """Compose the URL used to call "function" on the suite server.

        Uses this client's protocol, or the global configuration setting
        when none is set. Any kwargs are URL-encoded into a query string.
        """
        protocol = self.comms_protocol
        if protocol is None:
            # Fall back on the standard setting from global configuration.
            protocol = glbl_cfg().get(['communication', 'method'])
        base_url = '%s://%s:%s/%s' % (
            protocol, self.host, self.port, function)
        if not kwargs:
            return base_url
        import urllib
        query = urllib.urlencode(kwargs, doseq=True)
        return base_url + "?" + query

    def call_server_impl(self, url, method, payload):
        """Determine whether to use requests or urllib2 to call suite API.

        The "requests" implementation is used when the library can be
        imported and its version is at least 2.4.2; otherwise urllib2.

        If the call fails with ClientConnectError and a suite name is set,
        the contact file is checked for being left over from a stopped
        suite; if so a ClientConnectError saying the suite has stopped is
        raised instead of the original error.
        """
        impl = self._call_server_impl_urllib2
        try:
            import requests
        except ImportError:
            pass
        else:
            import re
            # Only the leading digits of the first three components are
            # used, so versions such as "2.5.0.dev0" or "2.9.0rc1" do not
            # break the integer conversion.
            version = []
            for part in requests.__version__.split(".")[:3]:
                digits = re.match(r"\d+", part)
                if digits is None:
                    break
                version.append(int(digits.group()))
            if version >= [2, 4, 2]:
                impl = self._call_server_impl_requests
        try:
            return impl(url, method, payload)
        except ClientConnectError as exc:
            if self.suite is None:
                raise
            # Cannot connect, perhaps suite is no longer running and is leaving
            # behind a contact file?
            try:
                self.srv_files_mgr.detect_old_contact_file(
                    self.suite, (self.host, self.port))
            except (AssertionError, SuiteServiceFileError):
                raise exc
            else:
                # self.srv_files_mgr.detect_old_contact_file should delete left
                # behind contact file if the old suite process no longer
                # exists. Should be safe to report that the suite has stopped.
                raise ClientConnectError(exc.args[0], exc.STOPPED % self.suite)

    def _call_server_impl_requests(self, url, method, payload):
        """Call server with "requests" library.

        Sends "payload" as JSON to "url", using POST when "method" is
        METHOD_POST and GET otherwise, with HTTP digest authentication.

        Returns the decoded JSON body, or the raw response text when the
        body is not valid JSON.

        Raises:
            CylcError: an SSL error mentioning "unknown protocol" occurred
                for an "https:" URL.
            ClientConnectError: any other SSL or request failure.
            ClientTimeout: the request timed out.
            ClientDeniedError: the server answered with status 401.
            ClientConnectedError: raise_for_status() reported an HTTP error.
        """
        import requests
        from requests.packages.urllib3.exceptions import InsecureRequestWarning
        warnings.simplefilter("ignore", InsecureRequestWarning)
        # The session is created on first use and then reused.
        if self.session is None:
            self.session = requests.Session()

        if method == self.METHOD_POST:
            session_method = self.session.post
        else:
            session_method = self.session.get
        comms_protocol = url.split(':', 1)[0]  # Can use urlparse?
        username, password, verify = self._get_auth(comms_protocol)
        try:
            ret = session_method(url,
                                 json=payload,
                                 verify=verify,
                                 proxies={},
                                 headers=self._get_headers(),
                                 auth=requests.auth.HTTPDigestAuth(
                                     username, password),
                                 timeout=self.timeout)
        except requests.exceptions.SSLError as exc:
            if "unknown protocol" in str(exc) and url.startswith("https:"):
                # Server is using http rather than https, for some reason.
                sys.stderr.write(self.ERROR_NO_HTTPS_SUPPORT.format(exc))
                raise CylcError("Cannot issue commands over unsecured http.")
            if cylc.flags.debug:
                traceback.print_exc()
            raise ClientConnectError(url, exc)
        except requests.exceptions.Timeout as exc:
            if cylc.flags.debug:
                traceback.print_exc()
            raise ClientTimeout(url, exc)
        except requests.exceptions.RequestException as exc:
            if cylc.flags.debug:
                traceback.print_exc()
            raise ClientConnectError(url, exc)
        if ret.status_code == 401:
            # Describe the access level attempted: anonymous credentials
            # are reported as "public", anything else as "private".
            access_desc = 'private'
            if self.auth == self.ANON_AUTH:
                access_desc = 'public'
            raise ClientDeniedError(url, self.prog_name, access_desc)
        if ret.status_code >= 400:
            # Show the server-side exception if one can be extracted from
            # the reply, else the whole reply text.
            exception_text = get_exception_from_html(ret.text)
            if exception_text:
                sys.stderr.write(exception_text)
            else:
                sys.stderr.write(ret.text)
        try:
            ret.raise_for_status()
        except requests.exceptions.HTTPError as exc:
            if cylc.flags.debug:
                traceback.print_exc()
            raise ClientConnectedError(url, exc)
        # The call succeeded: cache the passphrase unless it is the
        # NO_PASSPHRASE placeholder.
        if self.auth and self.auth[1] != NO_PASSPHRASE:
            self.srv_files_mgr.cache_passphrase(self.suite, self.owner,
                                                self.host, self.auth[1])
        try:
            return ret.json()
        except ValueError:
            return ret.text

    def _call_server_impl_urllib2(self, url, method, payload):
        """Call server with "urllib2" library.

        Sends "payload" (JSON-encoded, if any) to "url" with the given HTTP
        method and HTTP digest authentication.

        Returns the decoded JSON body, or the raw response text when the
        body is not valid JSON.

        Raises:
            CylcError: a URL error mentioning "unknown protocol" occurred
                for an "https:" URL.
            ClientTimeout: a URL error mentioning "timed out" occurred.
            ClientConnectError: any other urllib2.URLError.
            ClientError: any other exception while opening the URL.
            ClientDeniedError: the response code is 401.
            ClientConnectedError: the response code is 400 or above.
        """
        import json
        import urllib2
        import ssl
        # Where the ssl module offers it, switch the process-wide default
        # HTTPS context to the unverified one.
        unverified_context = getattr(ssl, '_create_unverified_context', None)
        if unverified_context is not None:
            ssl._create_default_https_context = unverified_context

        comms_protocol = url.split(':', 1)[0]  # Can use urlparse?
        username, password = self._get_auth(comms_protocol)[0:2]
        auth_manager = urllib2.HTTPPasswordMgrWithDefaultRealm()
        auth_manager.add_password(None, url, username, password)
        auth = urllib2.HTTPDigestAuthHandler(auth_manager)
        opener = urllib2.build_opener(auth, urllib2.HTTPSHandler())
        headers_list = self._get_headers().items()
        if payload:
            payload = json.dumps(payload)
            headers_list.append(('Accept', 'application/json'))
            json_headers = {
                'Content-Type': 'application/json',
                'Content-Length': len(payload)
            }
        else:
            payload = None
            json_headers = {'Content-Length': 0}
        opener.addheaders = headers_list
        req = urllib2.Request(url, payload, json_headers)

        # This is an unpleasant monkey patch, but there isn't an
        # alternative. urllib2 uses POST if there is a data payload
        # but that is not the correct criterion.
        # The difference is basically that POST changes
        # server state and GET doesn't.
        req.get_method = lambda: method
        try:
            response = opener.open(req, timeout=self.timeout)
        except urllib2.URLError as exc:
            if "unknown protocol" in str(exc) and url.startswith("https:"):
                # Server is using http rather than https, for some reason.
                sys.stderr.write(self.ERROR_NO_HTTPS_SUPPORT.format(exc))
                raise CylcError("Cannot issue commands over unsecured http.")
            if cylc.flags.debug:
                traceback.print_exc()
            if "timed out" in str(exc):
                raise ClientTimeout(url, exc)
            else:
                raise ClientConnectError(url, exc)
        except Exception as exc:
            if cylc.flags.debug:
                traceback.print_exc()
            raise ClientError(url, exc)

        # NOTE(review): urllib2's default error handling normally raises
        # HTTPError (a URLError subclass) for 4xx/5xx replies, in which case
        # the status checks below would never run and such replies would
        # surface as ClientConnectError above - confirm.
        if response.getcode() == 401:
            access_desc = 'private'
            if self.auth == self.ANON_AUTH:
                access_desc = 'public'
            raise ClientDeniedError(url, self.prog_name, access_desc)
        response_text = response.read()
        if response.getcode() >= 400:
            exception_text = get_exception_from_html(response_text)
            if exception_text:
                sys.stderr.write(exception_text)
            else:
                sys.stderr.write(response_text)
            raise ClientConnectedError(
                url, "%s HTTP return code" % response.getcode())
        # The call succeeded: cache the passphrase unless it is the
        # NO_PASSPHRASE placeholder.
        if self.auth and self.auth[1] != NO_PASSPHRASE:
            self.srv_files_mgr.cache_passphrase(self.suite, self.owner,
                                                self.host, self.auth[1])

        try:
            return json.loads(response_text)
        except ValueError:
            return response_text

    def _get_auth(self, protocol):
        """Return the (user, password, verify) triple for Digest Auth.

        An already-set value is returned unchanged. Otherwise the suite
        passphrase is looked up (plus the SSL certificate for "https"); if
        a usable passphrase is found the 'cylc' user is used, else the
        anonymous credentials. The result is remembered in self.auth.
        """
        if self.auth is not None:
            return self.auth
        self.auth = self.ANON_AUTH
        files_mgr = self.srv_files_mgr
        try:
            pphrase = files_mgr.get_auth_item(
                files_mgr.FILE_BASE_PASSPHRASE,
                self.suite, self.owner, self.host, content=True)
            verify = False
            if protocol == 'https':
                verify = files_mgr.get_auth_item(
                    files_mgr.FILE_BASE_SSL_CERT,
                    self.suite, self.owner, self.host)
        except SuiteServiceFileError:
            # Stay anonymous when the service files cannot be read.
            return self.auth
        if pphrase and pphrase != NO_PASSPHRASE:
            self.auth = ('cylc', pphrase, verify)
        return self.auth

    def _get_headers(self):
        """Build the HTTP headers that identify this client to the server."""
        headers = {}
        headers["User-Agent"] = "cylc/%s prog_name/%s uuid/%s" % (
            CYLC_VERSION, self.prog_name, self.my_uuid)
        headers["From"] = "%s@%s" % (get_user(), get_host())
        return headers

    def _load_contact_info(self):
        """Obtain suite owner, host, port info.

        Determine host and port using content in the contact file, unless
        already specified. When only the port is specified, the local host
        is assumed. The communication protocol and the API version are also
        taken from the contact file; a missing or invalid API version is
        treated as 0.

        Raises ClientInfoError if the contact file cannot be loaded or does
        not hold a usable port.
        """
        if self.host and self.port:
            return
        if self.port:
            # In case the contact file is corrupted, user can specify the port.
            self.host = get_host()
            return
        try:
            # Always trust the values in the contact file otherwise.
            data = self.srv_files_mgr.load_contact_file(
                self.suite, self.owner, self.host)
            # Port inside "try" block, as it needs a type conversion.
            # TypeError covers a missing port entry, i.e. int(None).
            self.port = int(data.get(self.srv_files_mgr.KEY_PORT))
        except (IOError, TypeError, ValueError, SuiteServiceFileError):
            raise ClientInfoError(self.suite)
        self.host = data.get(self.srv_files_mgr.KEY_HOST)
        self.owner = data.get(self.srv_files_mgr.KEY_OWNER)
        self.comms_protocol = data.get(self.srv_files_mgr.KEY_COMMS_PROTOCOL)
        try:
            self.api = int(data.get(self.srv_files_mgr.KEY_API))
        except (TypeError, ValueError):
            self.api = 0  # Assume cylc-7.5.0 or before
Example #6
0
class SuiteRuntimeServiceClient(object):
    """Client for calling the HTTP(S) API of running suites."""

    ANON_AUTH = ('anon', NO_PASSPHRASE, False)
    ERROR_NO_HTTPS_SUPPORT = (
        "ERROR: server has no HTTPS support," +
        " configure your global.rc file to use HTTP : {0}\n")
    METHOD = 'POST'
    METHOD_POST = 'POST'
    METHOD_GET = 'GET'

    def __init__(self,
                 suite,
                 owner=None,
                 host=None,
                 port=None,
                 timeout=None,
                 my_uuid=None,
                 print_uuid=False,
                 comms_protocol=None,
                 auth=None):
        """Store the details needed to talk to a suite.

        The owner defaults to the current user. A "localhost" host is
        replaced by the local host name and a host without any dot is
        expanded with get_fqdn_by_host. A timeout, when given, is stored as
        a float. A fresh UUID is generated unless my_uuid is supplied, and
        it is written to stderr when print_uuid is true.
        """
        self.suite = suite
        if not owner:
            owner = get_user()
        self.owner = owner
        self.host = host
        if self.host and self.host.split('.')[0] == 'localhost':
            self.host = get_host()
        elif self.host and '.' not in self.host:  # Not IP and no domain
            self.host = get_fqdn_by_host(self.host)
        self.port = port
        self.srv_files_mgr = SuiteSrvFilesManager()
        self.comms_protocol = comms_protocol
        if timeout is not None:
            timeout = float(timeout)
        self.timeout = timeout
        self.my_uuid = my_uuid or uuid4()
        if print_uuid:
            # Plain write instead of the "print >>" statement form, which
            # only works as a Python 2 statement.
            sys.stderr.write('%s\n' % self.my_uuid)

        self.prog_name = os.path.basename(sys.argv[0])
        self.auth = auth

    def clear_broadcast(self, **kwargs):
        """Clear broadcast runtime task settings; kwargs form the payload."""
        reply = self._call_server('clear_broadcast', payload=kwargs)
        return reply

    def expire_broadcast(self, **kwargs):
        """Expire broadcast runtime task settings; kwargs are passed on."""
        reply = self._call_server('expire_broadcast', **kwargs)
        return reply

    def get_broadcast(self, **kwargs):
        """Fetch broadcast settings with a GET request."""
        return self._call_server(
            'get_broadcast', method=self.METHOD_GET, **kwargs)

    def get_info(self, command, *args, **kwargs):
        """Fetch suite info for "command"; the method is always GET."""
        call_kwargs = dict(kwargs, method=self.METHOD_GET)
        return self._call_server(command, *args, **call_kwargs)

    def get_latest_state(self, full_mode):
        """Fetch the latest state of the suite (for the GUI) with a GET."""
        return self._call_server(
            'get_latest_state', method=self.METHOD_GET, full_mode=full_mode)

    def get_suite_state_summary(self):
        """Fetch the global, task, and family summary data structures.

        The server reply is passed through utf8_enforce before returning.
        """
        summary = self._call_server(
            'get_suite_state_summary', method=self.METHOD_GET)
        return utf8_enforce(summary)

    def get_tasks_by_state(self):
        """Fetch a dict containing lists of tasks keyed by state.

        Result in the form:
        {state: [(most_recent_time_string, task_name, point_string), ...]}
        """
        tasks = self._call_server('get_tasks_by_state', method=self.METHOD_GET)
        return tasks

    def identify(self):
        """Fetch the suite identity with a GET request."""
        identity = self._call_server('identify', method=self.METHOD_GET)
        return identity

    def put_broadcast(self, **kwargs):
        """Set broadcast runtime task settings; kwargs form the payload."""
        reply = self._call_server('put_broadcast', payload=kwargs)
        return reply

    def put_command(self, command, **kwargs):
        """Invoke the suite command "command", passing kwargs along."""
        reply = self._call_server(command, **kwargs)
        return reply

    def put_ext_trigger(self, event_message, event_id):
        """Send an external trigger (message and ID) to the suite."""
        return self._call_server(
            'put_ext_trigger', event_message=event_message, event_id=event_id)

    def put_message(self, task_id, priority, message):
        """Send a task message, with its priority, to the suite."""
        return self._call_server(
            'put_message', task_id=task_id, priority=priority, message=message)

    def reset(self, *args, **kwargs):
        """Do nothing; kept only for compatibility. Arguments are ignored."""
        return None

    def signout(self, *args, **kwargs):
        """Ask the server to forget this client. Arguments are ignored."""
        reply = self._call_server('signout')
        return reply

    def _compile_request(self, func_dict, host, comms_protocol=None):
        """Turn one call description into the parts of an HTTP request.

        The "payload", "method" (default: METHOD) and "function" items are
        removed from func_dict; whatever is left is URL-encoded into the
        query string. The protocol falls back on the global configuration
        when not given.

        Returns a (method, url, payload, comms_protocol) tuple.
        """
        body = func_dict.pop("payload", None)
        http_method = func_dict.pop("method", self.METHOD)
        func_name = func_dict.pop("function", None)
        protocol = comms_protocol
        if protocol is None:
            # Use standard setting from global configuration
            from cylc.cfgspec.globalcfg import GLOBAL_CFG
            protocol = GLOBAL_CFG.get(['communication', 'method'])
        url = '%s://%s:%s/%s' % (protocol, host, self.port, func_name)
        if func_dict:
            # Everything not popped above becomes a URL parameter.
            import urllib
            url = url + "?" + urllib.urlencode(func_dict, doseq=True)
        return (http_method, url, body, protocol)

    def _call_server(self, *func_dicts, **fargs):
        """Call one or more functions on the suite server.

        Accepts either:
        * a function name (str) as first argument, with the arguments of
          that function given as keywords; or
        * one or more dicts, each describing one call ("function" plus
          optional "method" and "payload"; other items become URL
          parameters). Keyword arguments are not used in this form.

        Returns what call_server_impl returns for the compiled requests.

        Raises ClientInfoError if the suite contact info cannot be loaded.

        The "http_proxy" and "https_proxy" environment variables are removed
        for the duration of the call and put back afterwards.
        """
        # Deal with the case of one function name passed by converting it
        # to the generic case: a dictionary of a single function and its
        # function arguments.
        if isinstance(func_dicts[0], str):
            func_dict = {"function": func_dicts[0]}
            func_dict.update(fargs)
            request_dicts = [func_dict]
        else:
            request_dicts = func_dicts

        try:
            self._load_contact_info()
        except (IOError, TypeError, ValueError, SuiteServiceFileError):
            # TypeError covers a contact file without a port, i.e. int(None).
            raise ClientInfoError(self.suite)
        # Each item is a tuple: method, url, payload, comms protocol.
        http_request_items = [
            self._compile_request(
                request_dict, self.host, self.comms_protocol)
            for request_dict in request_dicts]
        # Remove proxy settings from environment for now
        saved_environ = {}
        for key in ("http_proxy", "https_proxy"):
            val = os.environ.pop(key, None)
            # Test against None, not truthiness: a variable that was set to
            # an empty string must be restored as well.
            if val is not None:
                saved_environ[key] = val
        # Returns a list of http returns from the requests
        try:
            return self.call_server_impl(http_request_items)
        finally:
            os.environ.update(saved_environ)

    def call_server_impl(self, http_request_items):
        """Determine whether to use requests or urllib2 to call suite API.

        The "requests" implementation is used when the library can be
        imported and its version is at least 2.4.2; otherwise urllib2.

        If the call fails with ClientConnectError and a suite name is set,
        the contact file is checked for being left over from a stopped
        suite; if so a ClientConnectError saying the suite has stopped is
        raised instead of the original error.
        """
        method = self._call_server_impl_urllib2
        try:
            import requests
        except ImportError:
            pass
        else:
            import re
            # Only the leading digits of the first three components are
            # used, so versions such as "2.5.0.dev0" or "2.9.0rc1" do not
            # break the integer conversion.
            version = []
            for part in requests.__version__.split(".")[:3]:
                digits = re.match(r"\d+", part)
                if digits is None:
                    break
                version.append(int(digits.group()))
            if version >= [2, 4, 2]:
                method = self._call_server_impl_requests
        try:
            return method(http_request_items)
        except ClientConnectError as exc:
            if self.suite is None:
                raise
            # Cannot connect, perhaps suite is no longer running and is leaving
            # behind a contact file?
            try:
                self.srv_files_mgr.detect_old_contact_file(
                    self.suite, (self.host, self.port))
            except (AssertionError, SuiteServiceFileError):
                raise exc
            else:
                # self.srv_files_mgr.detect_old_contact_file should delete left
                # behind contact file if the old suite process no longer
                # exists. Should be safe to report that the suite has stopped.
                raise ClientConnectError(exc.args[0], exc.STOPPED % self.suite)

    def _call_server_impl_requests(self, http_request_items):
        """Call server with "requests" library.

        Each item of http_request_items is a (method, url, payload,
        comms_protocol) tuple. The requests are sent one after another,
        with the payload as JSON and HTTP digest authentication.

        Returns the result of the single request, or a list of results when
        there is more than one. A result is the decoded JSON body, or the
        raw response text when the body is not valid JSON.

        Raises:
            CylcError: an SSL error mentioning "unknown protocol" occurred
                for an "https:" URL.
            ClientConnectError: any other SSL or request failure.
            ClientTimeout: a request timed out.
            ClientDeniedError: the server answered with status 401.
            ClientConnectedError: raise_for_status() reported an HTTP error.
        """
        import requests
        from requests.packages.urllib3.exceptions import InsecureRequestWarning
        warnings.simplefilter("ignore", InsecureRequestWarning)
        # The session is created on first use and then reused.
        if not hasattr(self, "session"):
            self.session = requests.Session()

        http_return_items = []
        for method, url, payload, comms_protocol in http_request_items:
            if method is None:
                method = self.METHOD
            if method == self.METHOD_POST:
                session_method = self.session.post
            else:
                session_method = self.session.get
            username, password, verify = self._get_auth(comms_protocol)
            try:
                ret = session_method(url,
                                     json=payload,
                                     verify=verify,
                                     proxies={},
                                     headers=self._get_headers(),
                                     auth=requests.auth.HTTPDigestAuth(
                                         username, password),
                                     timeout=self.timeout)
            except requests.exceptions.SSLError as exc:
                if "unknown protocol" in str(exc) and url.startswith("https:"):
                    # Server is using http rather than https, for some reason.
                    sys.stderr.write(self.ERROR_NO_HTTPS_SUPPORT.format(exc))
                    raise CylcError(
                        "Cannot issue commands over unsecured http.")
                if cylc.flags.debug:
                    traceback.print_exc()
                raise ClientConnectError(url, exc)
            except requests.exceptions.Timeout as exc:
                if cylc.flags.debug:
                    traceback.print_exc()
                raise ClientTimeout(url, exc)
            except requests.exceptions.RequestException as exc:
                if cylc.flags.debug:
                    traceback.print_exc()
                raise ClientConnectError(url, exc)
            if ret.status_code == 401:
                # Describe the access level attempted: anonymous credentials
                # are reported as "public", anything else as "private".
                access_desc = 'private'
                if self.auth == self.ANON_AUTH:
                    access_desc = 'public'
                raise ClientDeniedError(url, self.prog_name, access_desc)
            if ret.status_code >= 400:
                # Show the server-side exception if one can be extracted
                # from the reply, else the whole reply text.
                exception_text = get_exception_from_html(ret.text)
                if exception_text:
                    sys.stderr.write(exception_text)
                else:
                    sys.stderr.write(ret.text)
            try:
                ret.raise_for_status()
            except requests.exceptions.HTTPError as exc:
                if cylc.flags.debug:
                    traceback.print_exc()
                raise ClientConnectedError(url, exc)
            # The call succeeded: cache the passphrase unless it is the
            # NO_PASSPHRASE placeholder.
            if self.auth and self.auth[1] != NO_PASSPHRASE:
                self.srv_files_mgr.cache_passphrase(self.suite, self.owner,
                                                    self.host, self.auth[1])
            try:
                ret = ret.json()
                http_return_items.append(ret)
            except ValueError:
                ret = ret.text
                http_return_items.append(ret)
        # Return a single http return or a list of them if multiple
        return (http_return_items
                if len(http_return_items) > 1 else http_return_items[0])

    def _call_server_impl_urllib2(self, http_request_items):
        """Call server with "urllib2" library.

        Each item of http_request_items is a (method, url, payload,
        comms_protocol) tuple. The requests are sent one after another,
        with the payload JSON-encoded (if any) and HTTP digest
        authentication.

        Returns the result of the single request, or a list of results when
        there is more than one. A result is the decoded JSON body, or the
        raw response text when the body is not valid JSON.

        Raises:
            CylcError: a URL error mentioning "unknown protocol" occurred
                for an "https:" URL.
            ClientTimeout: a URL error mentioning "timed out" occurred.
            ClientConnectError: any other urllib2.URLError.
            ClientError: any other exception while opening the URL.
            ClientDeniedError: the response code is 401.
            ClientConnectedError: the response code is 400 or above.
        """
        import json
        import urllib2
        import ssl
        # Where the ssl module offers it, switch the process-wide default
        # HTTPS context to the unverified one.
        if hasattr(ssl, '_create_unverified_context'):
            ssl._create_default_https_context = ssl._create_unverified_context

        http_return_items = []
        for method, url, payload, comms_protocol in http_request_items:
            if method is None:
                method = self.METHOD
            username, password = self._get_auth(comms_protocol)[0:2]
            auth_manager = urllib2.HTTPPasswordMgrWithDefaultRealm()
            auth_manager.add_password(None, url, username, password)
            auth = urllib2.HTTPDigestAuthHandler(auth_manager)
            opener = urllib2.build_opener(auth, urllib2.HTTPSHandler())
            headers_list = self._get_headers().items()
            if payload:
                payload = json.dumps(payload)
                headers_list.append(('Accept', 'application/json'))
                json_headers = {
                    'Content-Type': 'application/json',
                    'Content-Length': len(payload)
                }
            else:
                payload = None
                json_headers = {'Content-Length': 0}
            opener.addheaders = headers_list
            req = urllib2.Request(url, payload, json_headers)

            # This is an unpleasant monkey patch, but there isn't an
            # alternative. urllib2 uses POST if there is a data payload
            # but that is not the correct criterion.
            # The difference is basically that POST changes
            # server state and GET doesn't.
            # The lambda reads the loop variable "method"; the request is
            # opened within this same iteration.
            req.get_method = lambda: method
            try:
                response = opener.open(req, timeout=self.timeout)
            except urllib2.URLError as exc:
                if "unknown protocol" in str(exc) and url.startswith("https:"):
                    # Server is using http rather than https, for some reason.
                    sys.stderr.write(self.ERROR_NO_HTTPS_SUPPORT.format(exc))
                    raise CylcError(
                        "Cannot issue commands over unsecured http.")
                if cylc.flags.debug:
                    traceback.print_exc()
                if "timed out" in str(exc):
                    raise ClientTimeout(url, exc)
                else:
                    raise ClientConnectError(url, exc)
            except Exception as exc:
                if cylc.flags.debug:
                    traceback.print_exc()
                raise ClientError(url, exc)

            # NOTE(review): urllib2's default error handling normally raises
            # HTTPError (a URLError subclass) for 4xx/5xx replies, in which
            # case the status checks below would never run and such replies
            # would surface as ClientConnectError above - confirm.
            if response.getcode() == 401:
                access_desc = 'private'
                if self.auth == self.ANON_AUTH:
                    access_desc = 'public'
                raise ClientDeniedError(url, self.prog_name, access_desc)
            response_text = response.read()
            if response.getcode() >= 400:
                exception_text = get_exception_from_html(response_text)
                if exception_text:
                    sys.stderr.write(exception_text)
                else:
                    sys.stderr.write(response_text)
                raise ClientConnectedError(
                    url, "%s HTTP return code" % response.getcode())
            # The call succeeded: cache the passphrase unless it is the
            # NO_PASSPHRASE placeholder.
            if self.auth and self.auth[1] != NO_PASSPHRASE:
                self.srv_files_mgr.cache_passphrase(self.suite, self.owner,
                                                    self.host, self.auth[1])

            try:
                http_return_items.append(json.loads(response_text))
            except ValueError:
                http_return_items.append(response_text)
        # Return a single http return or a list of them if multiple
        return (http_return_items
                if len(http_return_items) > 1 else http_return_items[0])

    def _get_auth(self, protocol):
        """Return the (user, password, verify) triple for Digest Auth.

        An already-set value is returned unchanged. Otherwise the suite
        passphrase is looked up (plus the SSL certificate for "https"); if
        a usable passphrase is found the 'cylc' user is used, else the
        anonymous credentials. The result is remembered in self.auth.
        """
        if self.auth is not None:
            return self.auth
        self.auth = self.ANON_AUTH
        files_mgr = self.srv_files_mgr
        try:
            pphrase = files_mgr.get_auth_item(
                files_mgr.FILE_BASE_PASSPHRASE,
                self.suite, self.owner, self.host, content=True)
            verify = False
            if protocol == 'https':
                verify = files_mgr.get_auth_item(
                    files_mgr.FILE_BASE_SSL_CERT,
                    self.suite, self.owner, self.host)
        except SuiteServiceFileError:
            # Stay anonymous when the service files cannot be read.
            return self.auth
        if pphrase and pphrase != NO_PASSPHRASE:
            self.auth = ('cylc', pphrase, verify)
        return self.auth

    def _get_headers(self):
        """Build the HTTP headers that identify this client to the server."""
        headers = {}
        headers["User-Agent"] = "cylc/%s prog_name/%s uuid/%s" % (
            CYLC_VERSION, self.prog_name, self.my_uuid)
        headers["From"] = "%s@%s" % (get_user(), get_host())
        return headers

    def _load_contact_info(self):
        """Fill in suite host, port, owner and protocol where needed.

        Nothing is loaded when a port is already set: the host is kept if
        given, else the local host is used. Otherwise all four values are
        taken from the suite contact file.
        """
        if self.port:
            if not self.host:
                # In case the contact file is corrupted, user can specify
                # the port.
                self.host = get_host()
            return
        # Always trust the values in the contact file otherwise.
        files_mgr = self.srv_files_mgr
        contact = files_mgr.load_contact_file(
            self.suite, self.owner, self.host)
        self.host = contact.get(files_mgr.KEY_HOST)
        self.port = int(contact.get(files_mgr.KEY_PORT))
        self.owner = contact.get(files_mgr.KEY_OWNER)
        self.comms_protocol = contact.get(files_mgr.KEY_COMMS_PROTOCOL)
Example #7
0
class SuiteRuntimeServiceClient(object):
    """Client for calling the HTTP(S) API of running suites."""

    # (username, password, verify) used by _get_auth() when no usable
    # passphrase is found.
    ANON_AUTH = ('anon', NO_PASSPHRASE, False)
    # Client method name -> {API version: server function name}. _compat()
    # looks up the API version here and falls back to a default when the
    # version is not listed.
    COMPAT_MAP = {  # Limited pre-7.5.0 API compat mapping
        'clear_broadcast': {0: 'broadcast/clear'},
        'expire_broadcast': {0: 'broadcast/expire'},
        'get_broadcast': {0: 'broadcast/get'},
        'get_info': {0: 'info/'},
        'get_suite_state_summary': {0: 'state/get_state_summary'},
        'put_broadcast': {0: 'broadcast/put'},
        'put_command': {0: 'command/'},
        'put_ext_trigger': {0: 'ext-trigger/put'},
        'put_messages': {0: 'message/put', 1: 'put_message'},
    }
    # Written to STDERR (formatted with the exception) when an "https" call
    # fails with an "unknown protocol" error.
    ERROR_NO_HTTPS_SUPPORT = (
        "ERROR: server has no HTTPS support," +
        " configure your global.rc file to use HTTP : {0}\n"
    )
    # Default HTTP method of _call_server(), plus the two supported values.
    METHOD = 'POST'
    METHOD_POST = 'POST'
    METHOD_GET = 'GET'

    # Defaults used by put_messages(): seconds between tries, number of
    # tries, and the request timeout applied when self.timeout is None.
    MSG_RETRY_INTVL = 5.0
    MSG_MAX_TRIES = 7
    MSG_TIMEOUT = 30.0

    def __init__(
            self, suite, owner=None, host=None, port=None, timeout=None,
            my_uuid=None, print_uuid=False, auth=None):
        """Set up a client for a suite.

        Arguments:
            suite (str): suite name.
            owner (str): suite owner; get_user() is used when not given.
            host (str): suite host. A name whose first dot-separated part is
                "localhost" is replaced by get_host(); any other name
                without a dot is passed through get_fqdn_by_host().
            port: suite port, stored as given.
            timeout: request timeout; converted to float unless None.
            my_uuid: client identifier; a new uuid4() is used when not given.
            print_uuid (bool): if True, write the client UUID to STDERR.
            auth (tuple): preset (username, password, verify); None means
                it is looked up by _get_auth() when first needed.
        """
        self.suite = suite
        self.owner = owner or get_user()
        if host:
            if host.split('.')[0] == 'localhost':
                host = get_host()
            elif '.' not in host:  # Not IP and no domain
                host = get_fqdn_by_host(host)
        self.host = host
        self.port = port
        self.srv_files_mgr = SuiteSrvFilesManager()
        self.timeout = None if timeout is None else float(timeout)
        self.my_uuid = my_uuid or uuid4()
        if print_uuid:
            sys.stderr.write('%s\n' % self.my_uuid)

        self.prog_name = os.path.basename(sys.argv[0])
        self.auth = auth
        self.session = None
        # Content of the primary contact file.
        self.comms1 = {}
        # Content of the extra contact file, e.g. contact via ssh.
        self.comms2 = {}

    def _compat(self, name, default=None):
        """Return the server function name to use for a client method.

        Handle back-compat for pre-7.5.0 if relevant.

        Arguments:
            name (str): a key of COMPAT_MAP.
            default (str): name to use when the suite's API version has no
                entry for this method; defaults to "name" itself.
        """
        # The API version is only known once the contact info is loaded.
        self._load_contact_info()
        api_version = self.comms1.get(self.srv_files_mgr.KEY_API)
        fallback = name if default is None else default
        return self.COMPAT_MAP[name].get(api_version, fallback)

    def clear_broadcast(self, payload):
        """Clear broadcast runtime task settings.

        The payload is passed unchanged to the server function chosen by
        _compat(); the server's reply is returned.
        """
        server_func = self._compat('clear_broadcast')
        return self._call_server(server_func, payload=payload)

    def expire_broadcast(self, **kwargs):
        """Expire broadcast runtime task settings.

        Keyword arguments are forwarded to the server call unchanged.
        """
        server_func = self._compat('expire_broadcast')
        return self._call_server(server_func, **kwargs)

    def get_broadcast(self, **kwargs):
        """Return broadcast settings (sent as a GET request).

        Keyword arguments are forwarded to the server call unchanged.
        """
        server_func = self._compat('get_broadcast')
        return self._call_server(
            server_func, method=self.METHOD_GET, **kwargs)

    def get_info(self, command, **kwargs):
        """Return suite info (sent as a GET request).

        Arguments:
            command (str): name of the info command; it is appended to the
                prefix returned by _compat() ("info/" for API 0, otherwise
                empty).
            **kwargs: forwarded to the server call.
        """
        prefix = self._compat('get_info', default='')
        return self._call_server(
            prefix + command, method=self.METHOD_GET, **kwargs)

    def get_latest_state(self, full_mode=False):
        """Return latest state of the suite (for the GUI).

        Arguments:
            full_mode (bool): sent to the server's "get_latest_state"; for
                API 0 suites it is only echoed back in the result.

        Return:
            The server's reply, or, for API 0 (pre-7.5.0) suites, a dict
            assembled here from several older calls.
        """
        self._load_contact_info()
        if self.comms1.get(self.srv_files_mgr.KEY_API) != 0:
            return self._call_server(
                'get_latest_state',
                method=self.METHOD_GET, full_mode=full_mode)
        # Basic compat for pre-7.5.0 suites:
        # * Full mode only.
        # * Error content/size not supported.
        # * Report made-up main loop interval of 5.0 seconds.
        cylc_version = self.get_info('get_cylc_version')
        summary = self.get_suite_state_summary()
        ancestors = self.get_info('get_first_parent_ancestors')
        ancestors_pruned = self.get_info(
            'get_first_parent_ancestors', pruned=True)
        descendants = self.get_info('get_first_parent_descendants')
        return {
            'cylc_version': cylc_version,
            'full_mode': full_mode,
            'summary': summary,
            'ancestors': ancestors,
            'ancestors_pruned': ancestors_pruned,
            'descendants': descendants,
            'err_content': '',
            'err_size': 0,
            'mean_main_loop_interval': 5.0}

    def get_suite_state_summary(self):
        """Return the global, task, and family summary data structures.

        The server's reply (GET request) is passed through utf8_enforce().
        """
        server_func = self._compat('get_suite_state_summary')
        summary = self._call_server(server_func, method=self.METHOD_GET)
        return utf8_enforce(summary)

    def identify(self):
        """Return suite identity (GET request to "id/identify")."""
        # Note on compat: Suites on 7.6.0 or above can just call "identify",
        # but has compat for "id/identity".
        http_method = self.METHOD_GET
        return self._call_server('id/identify', method=http_method)

    def put_broadcast(self, payload):
        """Put/set broadcast runtime task settings.

        The payload is passed unchanged to the server function chosen by
        _compat(); the server's reply is returned.
        """
        server_func = self._compat('put_broadcast')
        return self._call_server(server_func, payload=payload)

    def put_command(self, command, **kwargs):
        """Invoke suite command.

        Arguments:
            command (str): command name; it is appended to the prefix
                returned by _compat() ("command/" for API 0, otherwise
                empty).
            **kwargs: forwarded to the server call.
        """
        prefix = self._compat('put_command', default='')
        return self._call_server(prefix + command, **kwargs)

    def put_ext_trigger(self, event_message, event_id):
        """Put external trigger.

        Both arguments are sent to the server as keyword arguments of the
        same names.
        """
        server_func = self._compat('put_ext_trigger')
        return self._call_server(
            server_func, event_message=event_message, event_id=event_id)

    def put_messages(self, payload):
        """Send task messages to suite server program.

        Up to "max tries" attempts are made, sleeping between failed
        attempts. While sending, a timeout of MSG_TIMEOUT is applied if the
        client has none; the original timeout is restored after each try.

        Arguments:
            payload (dict):
                task_job (str): Task job as "CYCLE/TASK_NAME/SUBMIT_NUM".
                event_time (str): Event time as string.
                messages (list): List in the form [[severity, message], ...].

        Return:
            The server's reply on success. For older suites (API 0 and 1),
            which take one message per call, a list with one reply per
            message. None if every attempt failed with a ClientError.

        Raise:
            ClientInfoError: contact info not found; never retried.
        """
        # NOTE(review): these settings are read from self.comms1 before
        # _compat() below loads the contact info, so the class defaults
        # apply unless comms1 was already populated by an earlier call.
        retry_intvl = float(self.comms1.get(
            self.srv_files_mgr.KEY_TASK_MSG_RETRY_INTVL,
            self.MSG_RETRY_INTVL))
        max_tries = int(self.comms1.get(
            self.srv_files_mgr.KEY_TASK_MSG_MAX_TRIES,
            self.MSG_MAX_TRIES))
        for i in range(1, max_tries + 1):  # 1..max_tries inclusive
            orig_timeout = self.timeout
            if self.timeout is None:
                self.timeout = self.MSG_TIMEOUT
            try:
                func_name = self._compat('put_messages')
                if func_name == 'put_messages':
                    results = self._call_server(func_name, payload=payload)
                else:
                    # Older servers accept a single message per call:
                    # * API 1 (7.5.0) names the level "severity".
                    # * API 0 (pre-7.5.0) names it "priority".
                    if func_name == 'put_message':
                        level_key = 'severity'
                    else:
                        level_key = 'priority'
                    cycle, name = payload['task_job'].split('/')[0:2]
                    task_id = '%s.%s' % (name, cycle)
                    # Start from an empty list on every attempt, so that a
                    # retry does not carry over partial results.
                    results = []
                    for severity, message in payload['messages']:
                        results.append(self._call_server(
                            func_name, task_id=task_id, message=message,
                            **{level_key: severity}))
            except ClientInfoError:
                # Contact info file not found, suite probably not running.
                # Don't bother with retry, suite restart will poll any way.
                raise
            except ClientError as exc:
                now = get_current_time_string()
                sys.stderr.write(
                    "%s WARNING - Message send failed, try %s of %s: %s\n" % (
                        now, i, max_tries, exc))
                if i < max_tries:
                    sys.stderr.write(
                        "   retry in %s seconds, timeout is %s\n" % (
                            retry_intvl, self.timeout))
                    sleep(retry_intvl)
                    # Reset in case contact info or passphrase change
                    self.comms1 = {}
                    self.host = None
                    self.port = None
                    self.auth = None
            else:
                if i > 1:
                    # Continue to write to STDERR, so users can easily see that
                    # it has recovered from previous failures.
                    sys.stderr.write(
                        "%s INFO - Send message: try %s of %s succeeded\n" % (
                            get_current_time_string(), i, max_tries))
                return results
            finally:
                self.timeout = orig_timeout

    def reset(self):
        """Compat method: takes no action and returns None."""
        return None

    def signout(self):
        """Tell server to forget this client; return the server's reply."""
        reply = self._call_server('signout')
        return reply

    def _call_server(self, function, method=METHOD, payload=None, **kwargs):
        """Build server URL + call it.

        When indirect comms settings (self.comms2) are loaded, the call is
        made via _call_server_via_comms2() instead and "method" is not used.

        Arguments:
            function (str): server function name; becomes the URL path.
            method (str): HTTP method, METHOD_POST or METHOD_GET.
            payload: data sent as the JSON body of the request.
            **kwargs: arguments of the server function, sent as the URL
                query string.

        Return:
            Whatever the underlying call implementation returns.
        """
        if self.comms2:
            return self._call_server_via_comms2(function, payload, **kwargs)
        url = self._call_server_get_url(function, **kwargs)
        # Remove proxy settings from environment for the duration of the
        # call. Compare against None, not truthiness, so that a variable
        # that was set to an empty string is put back afterwards too.
        saved_proxies = {}
        for key in ("http_proxy", "https_proxy"):
            val = os.environ.pop(key, None)
            if val is not None:
                saved_proxies[key] = val
        try:
            return self.call_server_impl(url, method, payload)
        finally:
            os.environ.update(saved_proxies)

    def _call_server_get_url(self, function, **kwargs):
        """Build request URL.

        The scheme comes from the contact file content (self.comms1) or,
        when absent there, from the global configuration. Any keyword
        arguments are URL-encoded and appended as the query string.
        """
        scheme = self.comms1.get(self.srv_files_mgr.KEY_COMMS_PROTOCOL)
        if scheme is None:
            # Use standard setting from global configuration
            scheme = glbl_cfg().get(['communication', 'method'])
        parts = ['%s://%s:%s/%s' % (scheme, self.host, self.port, function)]
        if kwargs:
            import urllib
            parts.append(urllib.urlencode(kwargs, doseq=True))
        return "?".join(parts)

    def call_server_impl(self, url, method, payload):
        """Call the suite API using "requests" if usable, else "urllib2".

        "requests" is used only when it can be imported and its version is
        at least 2.4.2.

        Raise:
            ClientConnectError: cannot connect. If the suite name is known
                and detect_old_contact_file() completes without error, the
                error is re-created with the "suite stopped" message.
        """
        try:
            import requests
        except ImportError:
            impl = self._call_server_impl_urllib2
        else:
            version = [int(_) for _ in requests.__version__.split(".")]
            if version >= [2, 4, 2]:
                impl = self._call_server_impl_requests
            else:
                impl = self._call_server_impl_urllib2
        try:
            return impl(url, method, payload)
        except ClientConnectError as exc:
            if self.suite is None:
                raise
            # Cannot connect, perhaps suite is no longer running and is leaving
            # behind a contact file?
            try:
                self.srv_files_mgr.detect_old_contact_file(
                    self.suite, (self.host, self.port))
            except (AssertionError, SuiteServiceFileError):
                # Report the original connection error, not this one.
                raise exc
            # self.srv_files_mgr.detect_old_contact_file should delete left
            # behind contact file if the old suite process no longer
            # exists. Should be safe to report that the suite has stopped.
            raise ClientConnectError(exc.args[0], exc.STOPPED % self.suite)

    def _call_server_impl_requests(self, url, method, payload):
        """Call server with "requests" library.

        A single requests.Session is created on first use and reused for
        later calls. The request uses HTTP Digest Auth with the credentials
        from _get_auth().

        Return:
            The decoded JSON reply, or the reply text if it is not JSON.

        Raise:
            CylcError: "https" URL but SSL reports an unknown protocol.
            ClientTimeout: the request timed out.
            ClientConnectError: any other SSL or request failure.
            ClientDeniedError: the server replied 401.
            ClientConnectedError: the server replied with an error status.
        """
        import requests
        from requests.packages.urllib3.exceptions import InsecureRequestWarning
        warnings.simplefilter("ignore", InsecureRequestWarning)
        if self.session is None:
            self.session = requests.Session()

        send = (
            self.session.post if method == self.METHOD_POST
            else self.session.get)
        scheme = url.split(':', 1)[0]  # Can use urlparse?
        username, password, verify = self._get_auth(scheme)
        try:
            ret = send(
                url,
                json=payload,
                verify=verify,
                proxies={},
                headers=self._get_headers(),
                auth=requests.auth.HTTPDigestAuth(username, password),
                timeout=self.timeout
            )
        except requests.exceptions.SSLError as exc:
            if "unknown protocol" in str(exc) and url.startswith("https:"):
                # Server is using http rather than https, for some reason.
                sys.stderr.write(self.ERROR_NO_HTTPS_SUPPORT.format(exc))
                raise CylcError(
                    "Cannot issue commands over unsecured http.")
            if cylc.flags.debug:
                traceback.print_exc()
            raise ClientConnectError(url, exc)
        except requests.exceptions.Timeout as exc:
            if cylc.flags.debug:
                traceback.print_exc()
            raise ClientTimeout(url, exc)
        except requests.exceptions.RequestException as exc:
            if cylc.flags.debug:
                traceback.print_exc()
            raise ClientConnectError(url, exc)

        status = ret.status_code
        if status == 401:
            # Describe the access level that was refused.
            if self.auth == self.ANON_AUTH:
                access_desc = 'public'
            else:
                access_desc = 'private'
            raise ClientDeniedError(url, self.prog_name, access_desc)
        if status >= 400:
            # Prefer the exception extracted from the HTML error page; fall
            # back to the raw reply text.
            sys.stderr.write(get_exception_from_html(ret.text) or ret.text)
        try:
            ret.raise_for_status()
        except requests.exceptions.HTTPError as exc:
            if cylc.flags.debug:
                traceback.print_exc()
            raise ClientConnectedError(url, exc)
        # The call worked: remember a real passphrase for later use.
        if self.auth and self.auth[1] != NO_PASSPHRASE:
            self.srv_files_mgr.cache_passphrase(
                self.suite, self.owner, self.host, self.auth[1])
        try:
            return ret.json()
        except ValueError:
            return ret.text

    def _call_server_impl_urllib2(self, url, method, payload):
        """Call server with "urllib2" library.

        The request uses HTTP Digest Auth with the credentials from
        _get_auth(). A payload, if any, is sent as a JSON body.

        Return:
            The decoded JSON reply, or the reply text if it is not JSON.

        Raise:
            CylcError: "https" URL but the error mentions unknown protocol.
            ClientTimeout: URL error whose text contains "timed out".
            ClientConnectError: any other URL error.
            ClientError: any other exception while opening the URL.
            ClientDeniedError: the reply code is 401.
            ClientConnectedError: the reply code is 400 or above.
        """
        import json
        import urllib2
        import ssl
        # Where this Python provides it, make unverified contexts the
        # default for HTTPS. NOTE: this replaces a module-level attribute of
        # "ssl", so it is not limited to this call.
        make_unverified = getattr(ssl, '_create_unverified_context', None)
        if make_unverified is not None:
            ssl._create_default_https_context = make_unverified

        scheme = url.split(':', 1)[0]  # Can use urlparse?
        username, password = self._get_auth(scheme)[0:2]
        password_mgr = urllib2.HTTPPasswordMgrWithDefaultRealm()
        password_mgr.add_password(None, url, username, password)
        digest_handler = urllib2.HTTPDigestAuthHandler(password_mgr)
        opener = urllib2.build_opener(digest_handler, urllib2.HTTPSHandler())
        client_headers = list(self._get_headers().items())
        if payload:
            body = json.dumps(payload)
            client_headers.append(('Accept', 'application/json'))
            request_headers = {'Content-Type': 'application/json',
                               'Content-Length': len(body)}
        else:
            body = None
            request_headers = {'Content-Length': 0}
        opener.addheaders = client_headers
        req = urllib2.Request(url, body, request_headers)

        # This is an unpleasant monkey patch, but there isn't an
        # alternative. urllib2 uses POST if there is a data payload
        # but that is not the correct criterion.
        # The difference is basically that POST changes
        # server state and GET doesn't.
        req.get_method = lambda: method
        try:
            response = opener.open(req, timeout=self.timeout)
        except urllib2.URLError as exc:
            if "unknown protocol" in str(exc) and url.startswith("https:"):
                # Server is using http rather than https, for some reason.
                sys.stderr.write(self.ERROR_NO_HTTPS_SUPPORT.format(exc))
                raise CylcError(
                    "Cannot issue commands over unsecured http.")
            if cylc.flags.debug:
                traceback.print_exc()
            if "timed out" in str(exc):
                raise ClientTimeout(url, exc)
            raise ClientConnectError(url, exc)
        except Exception as exc:
            if cylc.flags.debug:
                traceback.print_exc()
            raise ClientError(url, exc)

        if response.getcode() == 401:
            # Describe the access level that was refused.
            if self.auth == self.ANON_AUTH:
                access_desc = 'public'
            else:
                access_desc = 'private'
            raise ClientDeniedError(url, self.prog_name, access_desc)
        response_text = response.read()
        if response.getcode() >= 400:
            # Prefer the exception extracted from the HTML error page; fall
            # back to the raw reply text.
            sys.stderr.write(
                get_exception_from_html(response_text) or response_text)
            raise ClientConnectedError(
                url,
                "%s HTTP return code" % response.getcode())
        # The call worked: remember a real passphrase for later use.
        if self.auth and self.auth[1] != NO_PASSPHRASE:
            self.srv_files_mgr.cache_passphrase(
                self.suite, self.owner, self.host, self.auth[1])

        try:
            return json.loads(response_text)
        except ValueError:
            return response_text

    def _call_server_via_comms2(self, function, payload, **kwargs):
        """Call server via "cylc client --use-ssh".

        Call "cylc client --use-ssh" using `subprocess.Popen`. Payload and
        arguments of the API method are serialized as JSON and are written to a
        temporary file, which is then used as the STDIN of the "cylc client"
        command. The external call here should be even safer than a direct
        HTTP(S) call, as it can be blocked by SSH before it even gets a chance
        to make the subsequent HTTP(S) call.

        Arguments:
            function (str): name of API method, argument 1 of "cylc client".
            payload (str): extra data or information for the API method.
            **kwargs (dict): arguments for the API method.

        Return:
            The command's STDOUT, decoded from JSON.

        Raise:
            ClientError: the command exits with a non-zero return code.
        """
        import json
        from cylc.remote import remote_cylc_cmd
        command = ["client", function, self.suite]
        if payload:
            kwargs["payload"] = payload
        # With stdin=None, `remote_cylc_cmd` will:
        # * Set stdin to open(os.devnull)
        # * Add `-n` to the SSH command
        stdin = None
        if kwargs:
            from tempfile import TemporaryFile
            stdin = TemporaryFile()
        try:
            if stdin is not None:
                json.dump(kwargs, stdin)
                stdin.seek(0)
            proc = remote_cylc_cmd(
                command, self.owner, self.host, capture=True,
                ssh_login_shell=(self.comms1.get(
                    self.srv_files_mgr.KEY_SSH_USE_LOGIN_SHELL
                ) in ['True', 'true']),
                ssh_cylc=(r'%s/bin/cylc' % self.comms1.get(
                    self.srv_files_mgr.KEY_DIR_ON_SUITE_HOST)
                ),
                stdin=stdin,
            )
            out = proc.communicate()[0]
            return_code = proc.wait()
        finally:
            # The command has finished (or could not be run): release the
            # temporary file instead of leaving it open until garbage
            # collection.
            if stdin is not None:
                stdin.close()
        if return_code:
            from pipes import quote
            command_str = " ".join(quote(item) for item in command)
            raise ClientError(command_str, "return-code=%d" % return_code)
        return json.loads(out)

    def _get_auth(self, protocol):
        """Return (username, password, verify) for HTTP Digest Auth.

        The value is worked out once and kept in self.auth. Anonymous
        credentials are used unless the auth files can be read and hold a
        real passphrase. For "https", "verify" is the SSL certificate auth
        item; otherwise it is False.
        """
        if self.auth is not None:
            return self.auth
        self.auth = self.ANON_AUTH
        mgr = self.srv_files_mgr
        try:
            pphrase = mgr.get_auth_item(
                mgr.FILE_BASE_PASSPHRASE,
                self.suite, self.owner, self.host, content=True)
            verify = False
            if protocol == 'https':
                verify = mgr.get_auth_item(
                    mgr.FILE_BASE_SSL_CERT,
                    self.suite, self.owner, self.host)
        except SuiteServiceFileError:
            # Auth files not available: stay anonymous.
            return self.auth
        if pphrase and pphrase != NO_PASSPHRASE:
            self.auth = ('cylc', pphrase, verify)
        return self.auth

    def _get_headers(self):
        """Build the HTTP headers that identify this client.

        Return:
            dict: "User-Agent" made of the cylc version, the program name
                and the client UUID; "From" made of the values returned by
                get_user() and get_host(), joined with "@".
        """
        agent = "cylc/%s prog_name/%s uuid/%s" % (
            CYLC_VERSION, self.prog_name, self.my_uuid)
        sender = "%s@%s" % (get_user(), get_host())
        return {"User-Agent": agent, "From": sender}

    def _load_contact_info(self):
        """Obtain suite owner, host, port info.

        Determine host and port using content in port file, unless already
        specified. Nothing is read when a port is already set: if the host
        is missing it becomes get_host(). Otherwise the primary contact file
        is loaded into self.comms1 and the extra (indirect comms) contact
        file, if any, into self.comms2.

        Raise:
            ClientInfoError: the contact file cannot be loaded, or its port
                entry is missing or not an integer.
            ClientInfoUUIDError: the environment names this suite but with a
                UUID different from the one in the contact file.
        """
        if self.host and self.port:
            return
        if self.port:
            # In case the contact file is corrupted, user can specify the port.
            self.host = get_host()
            return
        try:
            # Always trust the values in the contact file otherwise.
            self.comms1 = self.srv_files_mgr.load_contact_file(
                self.suite, self.owner, self.host)
            # Port inside "try" block, as it needs a type conversion:
            # int() raises TypeError for a missing entry (None) and
            # ValueError for a non-numeric one.
            self.port = int(self.comms1.get(self.srv_files_mgr.KEY_PORT))
        except (IOError, TypeError, ValueError, SuiteServiceFileError):
            raise ClientInfoError(self.suite)
        else:
            # Check mismatch suite UUID
            env_suite = os.getenv(self.srv_files_mgr.KEY_NAME)
            env_uuid = os.getenv(self.srv_files_mgr.KEY_UUID)
            if (self.suite and env_suite and env_suite == self.suite and
                    env_uuid and
                    env_uuid != self.comms1.get(self.srv_files_mgr.KEY_UUID)):
                raise ClientInfoUUIDError(
                    env_uuid, self.comms1[self.srv_files_mgr.KEY_UUID])
            # All good
            self.host = self.comms1.get(self.srv_files_mgr.KEY_HOST)
            self.owner = self.comms1.get(self.srv_files_mgr.KEY_OWNER)
            if self.srv_files_mgr.KEY_API not in self.comms1:
                self.comms1[self.srv_files_mgr.KEY_API] = 0  # <=7.5.0 compat
        # Indirect comms settings; the extra contact file is optional.
        self.comms2.clear()
        try:
            self.comms2.update(self.srv_files_mgr.load_contact_file(
                self.suite, self.owner, self.host,
                SuiteSrvFilesManager.FILE_BASE_CONTACT2))
        except SuiteServiceFileError:
            pass
Example #8
0
class CommsDaemon(object):
    """Wrap HTTPS daemon for a suite."""

    def __init__(self, suite):
        """Read the comms settings and auth files, then start the daemon.

        Arguments:
            suite (str): suite name; used to look up the auth files and as
                the digest-auth realm in start().
        """
        self.suite = suite

        # Work out which ports may be used; they are tried in random order.
        base_port = GLOBAL_CFG.get(['communication', 'base port'])
        max_ports = GLOBAL_CFG.get(
            ['communication', 'maximum number of ports'])
        first_port = int(base_port)
        self.ok_ports = list(range(first_port, first_port + int(max_ports)))
        random.shuffle(self.ok_ports)

        # HTTP Digest Auth uses MD5 - pretty secure in this use case.
        # Extending it with extra algorithms is allowed, but won't be
        # supported by most browsers. requests and urllib2 are OK though.
        comms_options = GLOBAL_CFG.get(['communication', 'options'])
        if "SHA1" in comms_options:
            # Note 'SHA' rather than 'SHA1'.
            self.hash_algorithm = "SHA"
        else:
            self.hash_algorithm = "MD5"

        self.srv_files_mgr = SuiteSrvFilesManager()
        passwords = {
            'cylc': self.srv_files_mgr.get_auth_item(
                self.srv_files_mgr.FILE_BASE_PASSPHRASE, suite, content=True),
            'anon': NO_PASSPHRASE,
        }
        self.get_ha1 = cherrypy.lib.auth_digest.get_ha1_dict_plain(
            passwords, algorithm=self.hash_algorithm)

        # Certificate and key are both None unless both can be found.
        self.cert = None
        self.pkey = None
        try:
            cert = self.srv_files_mgr.get_auth_item(
                self.srv_files_mgr.FILE_BASE_SSL_CERT, suite)
            pkey = self.srv_files_mgr.get_auth_item(
                self.srv_files_mgr.FILE_BASE_SSL_PEM, suite)
        except SuiteServiceFileError:
            pass
        else:
            self.cert = cert
            self.pkey = pkey
        self.client_reporter = CommsClientReporter.get_inst()
        self.start()

    def shutdown(self):
        """Shutdown the daemon, if it has an engine."""
        if not hasattr(self, "engine"):
            return
        self.engine.exit()
        self.engine.block()

    def connect(self, obj, name):
        """Mount obj on the daemon under the path "/" + name."""
        mount_point = "/" + name
        cherrypy.tree.mount(obj, mount_point)

    def disconnect(self, obj):
        """Disconnect obj from the daemon (currently does nothing)."""

    def get_port(self):
        """Return the daemon port."""
        return self.port

    def report_connection_if_denied(self):
        """Report connection if denied (delegates to the client reporter)."""
        self.client_reporter.report_connection_if_denied()

    def start(self):
        """Start quick web service.

        Configure cherrypy, then try each allowed port in turn until the
        engine reports that it has started.

        Raise:
            Exception: the engine did not start on any of the ports.
        """
        config = cherrypy.config
        config["server.socket_host"] = '0.0.0.0'
        config["engine.autoreload.on"] = False
        try:
            # Imported only to find out whether OpenSSL is available.
            from OpenSSL import SSL, crypto
            config['server.ssl_module'] = 'pyopenSSL'
            config['server.ssl_certificate'] = self.cert
            config['server.ssl_private_key'] = self.pkey
        except ImportError:
            ERR.warning("no HTTPS/OpenSSL support")
        config['log.screen'] = None
        digest_settings = {
            'tools.auth_digest.on': True,
            'tools.auth_digest.realm': self.suite,
            'tools.auth_digest.get_ha1': self.get_ha1,
            'tools.auth_digest.key': binascii.hexlify(os.urandom(16)),
            'tools.auth_digest.algorithm': self.hash_algorithm,
        }
        config.update(digest_settings)
        cherrypy.tools.connect_log = cherrypy.Tool(
            'on_end_resource', self.report_connection_if_denied)
        config['tools.connect_log.on'] = True
        self.engine = cherrypy.engine
        for port in self.ok_ports:
            config["server.socket_port"] = port
            try:
                cherrypy.engine.start()
                cherrypy.engine.wait(cherrypy.engine.states.STARTED)
            except Exception:
                if cylc.flags.debug:
                    traceback.print_exc()
                # We need to reinitialise the httpserver for each port attempt.
                cherrypy.server.httpserver = None
                continue
            if cherrypy.engine.state == cherrypy.engine.states.STARTED:
                self.port = port
                return
        raise Exception("No available ports")
Example #9
0
class BaseCommsClient(object):
    """Base class for client-side suite object interfaces.

    Builds a URL for a named function on a suite server and sends the
    request with "requests" when a recent enough version can be imported,
    otherwise with "urllib2". Requests use HTTP Digest authentication.
    """

    ACCESS_DESCRIPTION = 'private'
    METHOD = 'POST'  # HTTP method used when the caller does not give one.
    METHOD_POST = 'POST'
    METHOD_GET = 'GET'

    def __init__(self, suite, owner=USER, host=None, port=None, timeout=None,
                 my_uuid=None, print_uuid=False):
        """Store connection settings; no network access happens here.

        suite, owner, host, port -- identify the suite server; host and
            port may be None, in which case they are loaded from the suite
            contact file on first use.
        timeout -- request timeout, converted to float unless None.
        my_uuid -- client identifier; a new uuid4 is generated if not given.
        print_uuid -- if true, write the client identifier to stderr.
        """
        self.suite = suite
        self.owner = owner
        self.host = host
        self.port = port
        if timeout is not None:
            timeout = float(timeout)
        self.timeout = timeout
        self.my_uuid = my_uuid or uuid4()
        if print_uuid:
            print >> sys.stderr, '%s' % self.my_uuid
        self.srv_files_mgr = SuiteSrvFilesManager()
        self.prog_name = os.path.basename(sys.argv[0])
        self.server_cert = None  # Cached by _get_verify().
        self.auth = None  # Cached by _get_auth().

    def call_server_func(self, category, fname, **fargs):
        """Call server_object.fname(**fargs) and return the server's reply.

        The keyword arguments "payload" (request body) and "method" (HTTP
        method, default self.METHOD) are consumed here. Any other keyword
        arguments are sent as URL query parameters.

        Raise ConnectionInfoError if the suite contact information cannot
        be loaded.
        """
        if self.host is None and self.port is not None:
            self.host = get_hostname()
        try:
            self._load_contact_info()
        except (IOError, TypeError, ValueError, SuiteServiceFileError):
            # TypeError covers int(None) when the contact data has no port.
            raise ConnectionInfoError(self.suite)
        handle_proxies()
        payload = fargs.pop("payload", None)
        method = fargs.pop("method", self.METHOD)
        host = self.host
        if host == "localhost":
            host = get_hostname().split(".")[0]
        url = 'https://%s:%s/%s/%s' % (host, self.port, category, fname)
        if fargs:
            import urllib
            params = urllib.urlencode(fargs, doseq=True)
            url += "?" + params
        return self._get_data_from_url(url, payload, method=method)

    def _get_data_from_url(self, url, json_data, method=None):
        """Send the request with requests if usable, otherwise urllib2.

        requests is used only if it can be imported and its version is at
        least 2.4.2.
        """
        requests_ok = True
        try:
            import requests
        except ImportError:
            requests_ok = False
        else:
            import re
            # Use only the leading digits of each component, so a version
            # such as "2.4.2.dev0" cannot raise ValueError.
            version = []
            for item in requests.__version__.split(".")[:3]:
                match = re.match(r"\d+", item)
                if match is None:
                    break
                version.append(int(match.group()))
            if version < [2, 4, 2]:
                requests_ok = False
        if requests_ok:
            return self._get_data_from_url_with_requests(
                url, json_data, method=method)
        return self._get_data_from_url_with_urllib2(
            url, json_data, method=method)

    def _get_data_from_url_with_requests(self, url, json_data, method=None):
        """Send the request using the requests library.

        Return the decoded JSON reply, or the reply text if it is not JSON.

        Raise ConnectionTimeout on a requests timeout, ConnectionDeniedError
        on a 401 reply, and ConnectionError on any other request failure or
        HTTP error status.
        """
        import requests
        username, password = self._get_auth()
        auth = requests.auth.HTTPDigestAuth(username, password)
        if not hasattr(self, "session"):
            # One session per client, created on first use.
            self.session = requests.Session()
        if method is None:
            method = self.METHOD
        if method == self.METHOD_POST:
            session_method = self.session.post
        else:
            session_method = self.session.get
        try:
            ret = session_method(
                url,
                json=json_data,
                verify=self._get_verify(),
                proxies={},
                headers=self._get_headers(),
                auth=auth,
                timeout=self.timeout
            )
        except requests.exceptions.SSLError as exc:
            if "unknown protocol" in str(exc) and url.startswith("https:"):
                # Server is using http rather than https, for some reason.
                sys.stderr.write(WARNING_NO_HTTPS_SUPPORT.format(exc))
                # Retry with the same HTTP method as the failed attempt.
                return self._get_data_from_url_with_requests(
                    url.replace("https:", "http:", 1), json_data,
                    method=method)
            if cylc.flags.debug:
                import traceback
                traceback.print_exc()
            raise ConnectionError(url, exc)
        except requests.exceptions.Timeout as exc:
            if cylc.flags.debug:
                import traceback
                traceback.print_exc()
            raise ConnectionTimeout(url, exc)
        except requests.exceptions.RequestException as exc:
            if cylc.flags.debug:
                import traceback
                traceback.print_exc()
            raise ConnectionError(url, exc)
        if ret.status_code == 401:
            raise ConnectionDeniedError(url, self.prog_name,
                                        self.ACCESS_DESCRIPTION)
        if ret.status_code >= 400:
            # Show the server-side error before raising below.
            from cylc.network.https.util import get_exception_from_html
            exception_text = get_exception_from_html(ret.text)
            if exception_text:
                sys.stderr.write(exception_text)
            else:
                sys.stderr.write(ret.text)
        try:
            ret.raise_for_status()
        except requests.exceptions.HTTPError as exc:
            if cylc.flags.debug:
                import traceback
                traceback.print_exc()
            raise ConnectionError(url, exc)
        if self.auth and self.auth[1] != NO_PASSPHRASE:
            # The request succeeded with a real passphrase: cache it.
            self.srv_files_mgr.cache_passphrase(
                self.suite, self.owner, self.host, self.auth[1])
        try:
            return ret.json()
        except ValueError:
            return ret.text

    def _get_data_from_url_with_urllib2(self, url, json_data, method=None):
        """Send the request using urllib2.

        Return the decoded JSON reply, or the reply text if it is not JSON.

        Raise ConnectionTimeout if the URLError text contains "timed out",
        and ConnectionError on any other failure to open the URL.
        """
        import json
        import urllib2
        import ssl
        # SECURITY NOTE(review): this replaces the process-wide default HTTPS
        # context with an unverified one, so server certificates are not
        # checked on this code path. Confirm that this is intended.
        if hasattr(ssl, '_create_unverified_context'):
            ssl._create_default_https_context = ssl._create_unverified_context
        if method is None:
            method = self.METHOD
        orig_json_data = json_data
        username, password = self._get_auth()
        auth_manager = urllib2.HTTPPasswordMgrWithDefaultRealm()
        auth_manager.add_password(None, url, username, password)
        auth = urllib2.HTTPDigestAuthHandler(auth_manager)
        opener = urllib2.build_opener(auth, urllib2.HTTPSHandler())
        # list() so that append works whatever items() returns.
        headers_list = list(self._get_headers().items())
        if json_data:
            json_data = json.dumps(json_data)
            headers_list.append(('Accept', 'application/json'))
            json_headers = {'Content-Type': 'application/json',
                            'Content-Length': len(json_data)}
        else:
            json_data = None
            json_headers = {'Content-Length': 0}
        opener.addheaders = headers_list
        req = urllib2.Request(url, json_data, json_headers)

        # This is an unpleasant monkey patch, but there isn't an alternative.
        # urllib2 uses POST if there is a data payload, but that is not the
        # correct criterion. The difference is basically that POST changes
        # server state and GET doesn't.
        req.get_method = lambda: method
        try:
            response = opener.open(req, timeout=self.timeout)
        except urllib2.URLError as exc:
            if "unknown protocol" in str(exc) and url.startswith("https:"):
                # Server is using http rather than https, for some reason.
                sys.stderr.write(WARNING_NO_HTTPS_SUPPORT.format(exc))
                # Retry with the same HTTP method as the failed attempt.
                return self._get_data_from_url_with_urllib2(
                    url.replace("https:", "http:", 1), orig_json_data,
                    method=method)
            if cylc.flags.debug:
                import traceback
                traceback.print_exc()
            if "timed out" in str(exc):
                raise ConnectionTimeout(url, exc)
            else:
                raise ConnectionError(url, exc)
        except Exception as exc:
            if cylc.flags.debug:
                import traceback
                traceback.print_exc()
            raise ConnectionError(url, exc)

        # NOTE(review): urllib2 openers normally raise HTTPError (a URLError)
        # for 4xx/5xx replies, so the status checks below may never trigger;
        # confirm before relying on ConnectionDeniedError from this path.
        if response.getcode() == 401:
            raise ConnectionDeniedError(url, self.prog_name,
                                        self.ACCESS_DESCRIPTION)
        response_text = response.read()
        if response.getcode() >= 400:
            from cylc.network.https.util import get_exception_from_html
            exception_text = get_exception_from_html(response_text)
            if exception_text:
                sys.stderr.write(exception_text)
            else:
                sys.stderr.write(response_text)
            raise ConnectionError(url,
                                  "%s HTTP return code" % response.getcode())
        if self.auth and self.auth[1] != NO_PASSPHRASE:
            # The request succeeded with a real passphrase: cache it.
            self.srv_files_mgr.cache_passphrase(
                self.suite, self.owner, self.host, self.auth[1])
        try:
            return json.loads(response_text)
        except ValueError:
            return response_text

    def _get_auth(self):
        """Return a (user, password) pair for Digest Auth.

        The pair is ('cylc', passphrase) if a suite passphrase other than
        NO_PASSPHRASE can be loaded, else ('anon', NO_PASSPHRASE). The
        result is cached in self.auth.
        """
        if self.auth is None:
            self.auth = ('anon', NO_PASSPHRASE)
            try:
                pphrase = self.srv_files_mgr.get_auth_item(
                    self.srv_files_mgr.FILE_BASE_PASSPHRASE,
                    self.suite, self.owner, self.host, content=True)
            except SuiteServiceFileError:
                pass
            else:
                if pphrase and pphrase != NO_PASSPHRASE:
                    self.auth = ('cylc', pphrase)
        return self.auth

    def _get_headers(self):
        """Return HTTP headers identifying the client."""
        user_agent_string = (
            "cylc/%s prog_name/%s uuid/%s" % (
                CYLC_VERSION, self.prog_name, self.my_uuid
            )
        )
        auth_info = "%s@%s" % (USER, get_hostname())
        return {"User-Agent": user_agent_string,
                "From": auth_info}

    def _get_verify(self):
        """Return the server certificate if possible.

        Return False if the certificate item cannot be obtained. The result
        is cached in self.server_cert and passed to requests as "verify".
        """
        if self.server_cert is None:
            try:
                self.server_cert = self.srv_files_mgr.get_auth_item(
                    self.srv_files_mgr.FILE_BASE_SSL_CERT,
                    self.suite, self.owner, self.host)
            except SuiteServiceFileError:
                self.server_cert = False
        return self.server_cert

    def _load_contact_info(self):
        """Obtain suite owner, host, port info.

        Determine host and port using content in port file, unless already
        specified. Values already set on this object are not overwritten.
        """
        if self.host and self.port:
            return
        data = self.srv_files_mgr.load_contact_file(
            self.suite, self.owner, self.host)
        if not self.host:
            self.host = data.get(self.srv_files_mgr.KEY_HOST)
        if not self.port:
            self.port = int(data.get(self.srv_files_mgr.KEY_PORT))
        if not self.owner:
            self.owner = data.get(self.srv_files_mgr.KEY_OWNER)

    def reset(self, *args, **kwargs):
        """Do nothing; accepts and ignores any arguments."""
        pass

    def signout(self, *args, **kwargs):
        """Do nothing; accepts and ignores any arguments."""
        pass
Example #10
0
class RemoteJobHostManager(object):
    """Manage a remote job host."""

    _INSTANCE = None  # Singleton, created by the first get_inst() call.

    @classmethod
    def get_inst(cls):
        """Return a singleton instance of this class."""
        if cls._INSTANCE is None:
            cls._INSTANCE = cls()
        return cls._INSTANCE

    def __init__(self):
        self.initialised = {}  # {(host, owner): should_unlink, ...}
        self.single_task_mode = False
        self.suite_srv_files_mgr = SuiteSrvFilesManager()

    def init_suite_run_dir(self, reg, host, owner):
        """Initialise suite run dir on a user@host.

        Create SUITE_RUN_DIR/log/job/ if necessary.
        Install suite contact environment file.
        Install suite python modules.

        Do nothing for the local host with no owner or the current user,
        for a (host, owner) that is already initialised, or in single task
        mode.

        Raise RemoteJobHostInitError if initialisation cannot complete.

        """
        if host is None:
            host = 'localhost'
        if ((host, owner) in [('localhost', None), ('localhost', USER)]
                or (host, owner) in self.initialised or self.single_task_mode):
            return
        user_at_host = host
        if owner:
            user_at_host = owner + '@' + host

        r_suite_run_dir = GLOBAL_CFG.get_derived_host_item(
            reg, 'suite run directory', host, owner)
        r_log_job_dir = GLOBAL_CFG.get_derived_host_item(
            reg, 'suite job log directory', host, owner)
        r_suite_srv_dir = os.path.join(r_suite_run_dir,
                                       self.suite_srv_files_mgr.DIR_BASE_SRV)

        # Create a UUID file in the service directory.
        # If remote host has the file in its service directory, we can assume
        # that the remote host has a shared file system with the suite host.
        ssh_tmpl = GLOBAL_CFG.get_host_item('remote shell template', host,
                                            owner)
        uuid_str = str(uuid4())
        uuid_fname = os.path.join(
            self.suite_srv_files_mgr.get_suite_srv_dir(reg), uuid_str)
        try:
            open(uuid_fname, 'wb').close()
            proc = Popen(
                shlex.split(ssh_tmpl) + [
                    '-n', user_at_host, 'test', '-e',
                    os.path.join(r_suite_srv_dir, uuid_str)],
                stdout=PIPE, stderr=PIPE)
            # communicate() drains both pipes, so the child cannot block on
            # a full pipe as it could with a bare wait().
            proc.communicate()
            if proc.returncode == 0:
                # Initialised, but no need to tidy up
                self.initialised[(host, owner)] = False
                return
        finally:
            try:
                os.unlink(uuid_fname)
            except OSError:
                pass

        cmds = []
        # Command to create suite directory structure on remote host.
        cmds.append(
            shlex.split(ssh_tmpl) + [
                '-n', user_at_host, 'mkdir', '-p', r_suite_run_dir,
                r_log_job_dir, r_suite_srv_dir
            ])
        # The copy template is needed by both copy commands below, so fetch
        # it unconditionally; it used to be unbound for the python library
        # copy when the communication method was "poll".
        scp_tmpl = GLOBAL_CFG.get_host_item('remote copy template', host,
                                            owner)
        # Command to copy contact and authentication files to remote host.
        # Note: no need to do this if task communication method is "poll".
        should_unlink = GLOBAL_CFG.get_host_item('task communication method',
                                                 host, owner) != "poll"
        if should_unlink:
            cmds.append(
                shlex.split(scp_tmpl) + [
                    '-p',
                    self.suite_srv_files_mgr.get_contact_file(reg),
                    self.suite_srv_files_mgr.get_auth_item(
                        self.suite_srv_files_mgr.FILE_BASE_PASSPHRASE, reg),
                    self.suite_srv_files_mgr.get_auth_item(
                        self.suite_srv_files_mgr.FILE_BASE_SSL_CERT, reg),
                    user_at_host + ':' + r_suite_srv_dir + '/'
                ])
        # Command to copy python library to remote host.
        suite_run_py = os.path.join(
            GLOBAL_CFG.get_derived_host_item(reg, 'suite run directory'),
            'python')
        if os.path.isdir(suite_run_py):
            cmds.append(
                shlex.split(scp_tmpl) + [
                    '-pr', suite_run_py, user_at_host + ':' + r_suite_run_dir +
                    '/'
                ])
        # Run commands in sequence.
        for cmd in cmds:
            proc = Popen(cmd, stdout=PIPE, stderr=PIPE)
            out, err = proc.communicate()
            if proc.wait():
                raise RemoteJobHostInitError(
                    RemoteJobHostInitError.MSG_INIT, user_at_host,
                    ' '.join([quote(item) for item in cmd]), proc.returncode,
                    out, err)
        self.initialised[(host, owner)] = should_unlink
        LOG.info('Initialised %s:%s' % (user_at_host, r_suite_run_dir))

    def unlink_suite_contact_files(self, reg):
        """Remove suite contact files from initialised hosts.

        This is called on shutdown, so we don't want anything to hang.
        Terminate any incomplete SSH commands after 10 seconds.
        Failed commands are reported as warnings, not raised.
        """
        from time import sleep
        # Issue all SSH commands in parallel
        procs = {}
        for (host, owner), should_unlink in self.initialised.items():
            if not should_unlink:
                continue
            user_at_host = host
            if owner:
                user_at_host = owner + '@' + host
            ssh_tmpl = GLOBAL_CFG.get_host_item('remote shell template', host,
                                                owner)
            r_suite_contact_file = os.path.join(
                GLOBAL_CFG.get_derived_host_item(reg, 'suite run directory',
                                                 host, owner),
                SuiteSrvFilesManager.DIR_BASE_SRV,
                SuiteSrvFilesManager.FILE_BASE_CONTACT)
            cmd = shlex.split(ssh_tmpl) + [
                '-n', user_at_host, 'rm', '-f', r_suite_contact_file
            ]
            procs[user_at_host] = (cmd, Popen(cmd, stdout=PIPE, stderr=PIPE))
        # Wait for commands to complete for a max of 10 seconds
        timeout = time() + 10.0
        while procs and time() < timeout:
            for user_at_host, (cmd, proc) in procs.copy().items():
                if proc.poll() is None:
                    continue
                del procs[user_at_host]
                out, err = proc.communicate()
                if proc.wait():
                    ERR.warning(
                        RemoteJobHostInitError(
                            RemoteJobHostInitError.MSG_TIDY, user_at_host,
                            ' '.join([quote(item) for item in cmd]),
                            proc.returncode, out, err))
            if procs:
                # Pause between polls instead of spinning the CPU.
                sleep(0.1)
        # Terminate any remaining commands
        for user_at_host, (cmd, proc) in procs.items():
            try:
                proc.terminate()
            except OSError:
                pass
            out, err = proc.communicate()
            if proc.wait():
                ERR.warning(
                    RemoteJobHostInitError(
                        RemoteJobHostInitError.MSG_TIDY, user_at_host,
                        ' '.join([quote(item) for item in cmd]),
                        proc.returncode, out, err))
Example #11
0
class CommsDaemon(object):
    """Wrap HTTPS daemon for a suite."""

    def __init__(self, suite):
        """Read the communication settings for suite and start the daemon."""
        # Suite only needed for back-compat with old clients (see below):
        self.suite = suite

        # Work out which ports may be used, then try them in random order.
        base_port = GLOBAL_CFG.get(['communication', 'base port'])
        max_ports = GLOBAL_CFG.get(
            ['communication', 'maximum number of ports'])
        first_port = int(base_port)
        self.ok_ports = range(first_port, int(base_port) + int(max_ports))
        random.shuffle(self.ok_ports)

        comms_options = GLOBAL_CFG.get(['communication', 'options'])
        # HTTP Digest Auth uses MD5 - pretty secure in this use case.
        # Extending it with extra algorithms is allowed, but won't be
        # supported by most browsers. requests and urllib2 are OK though.
        # Note 'SHA' rather than 'SHA1'.
        self.hash_algorithm = "SHA" if "SHA1" in comms_options else "MD5"

        self.srv_files_mgr = SuiteSrvFilesManager()
        # Digest credentials: the suite passphrase for user 'cylc', and a
        # fixed placeholder for user 'anon'.
        self.get_ha1 = cherrypy.lib.auth_digest.get_ha1_dict_plain(
            {
                'cylc': self.srv_files_mgr.get_auth_item(
                    self.srv_files_mgr.FILE_BASE_PASSPHRASE,
                    suite, content=True),
                'anon': NO_PASSPHRASE
            },
            algorithm=self.hash_algorithm)

        # Use the certificate and private key only if both can be obtained.
        try:
            ssl_cert = self.srv_files_mgr.get_auth_item(
                self.srv_files_mgr.FILE_BASE_SSL_CERT, suite)
            ssl_pkey = self.srv_files_mgr.get_auth_item(
                self.srv_files_mgr.FILE_BASE_SSL_PEM, suite)
        except SuiteServiceFileError:
            ssl_cert = None
            ssl_pkey = None
        self.cert = ssl_cert
        self.pkey = ssl_pkey

        self.client_reporter = CommsClientReporter.get_inst()
        self.start()

    def shutdown(self):
        """Stop the engine, if there is one, and block until it has exited."""
        if not hasattr(self, "engine"):
            return
        self.engine.exit()
        self.engine.block()

    def connect(self, obj, name):
        """Mount obj on the daemon under the path "/" + name."""
        mount_point = "/" + name
        cherrypy.tree.mount(obj, mount_point)

    def disconnect(self, obj):
        """Disconnect obj from the daemon (currently does nothing)."""
        return None

    def get_port(self):
        """Return the port the daemon is serving on."""
        return self.port

    def report_connection_if_denied(self):
        """Ask the client reporter to report the connection if denied."""
        reporter = self.client_reporter
        reporter.report_connection_if_denied()

    def start(self):
        """Configure cherrypy and start serving on the first usable port.

        The ports in self.ok_ports are tried in order. The first port on
        which the cherrypy engine reaches the STARTED state is recorded in
        self.port. Raise Exception if none of the ports can be used.
        """
        # cherrypy.config["tools.encode.on"] = True
        # cherrypy.config["tools.encode.encoding"] = "utf-8"
        cherrypy.config["server.socket_host"] = '0.0.0.0'
        cherrypy.config["engine.autoreload.on"] = False
        try:
            # Imported only to find out whether pyOpenSSL is available.
            from OpenSSL import SSL, crypto
        except ImportError:
            ERR.warning("no HTTPS/OpenSSL support")
        else:
            cherrypy.config['server.ssl_module'] = 'pyopenSSL'
            cherrypy.config['server.ssl_certificate'] = self.cert
            cherrypy.config['server.ssl_private_key'] = self.pkey
        cherrypy.config['log.screen'] = None

        # Digest authentication settings, with a fresh random key.
        digest_key = binascii.hexlify(os.urandom(16))
        digest_settings = {
            'tools.auth_digest.on': True,
            'tools.auth_digest.realm': self.suite,
            'tools.auth_digest.get_ha1': self.get_ha1,
            'tools.auth_digest.key': digest_key,
            'tools.auth_digest.algorithm': self.hash_algorithm,
        }
        cherrypy.config.update(digest_settings)

        # Run the denied-connection report at the end of every request.
        cherrypy.tools.connect_log = cherrypy.Tool(
            'on_end_resource', self.report_connection_if_denied)
        cherrypy.config['tools.connect_log.on'] = True

        self.engine = cherrypy.engine
        for candidate_port in self.ok_ports:
            cherrypy.config["server.socket_port"] = candidate_port
            try:
                cherrypy.engine.start()
                cherrypy.engine.wait(cherrypy.engine.states.STARTED)
            except Exception:
                if cylc.flags.debug:
                    traceback.print_exc()
                # We need to reinitialise the httpserver for each port attempt.
                cherrypy.server.httpserver = None
                continue
            if cherrypy.engine.state == cherrypy.engine.states.STARTED:
                self.port = candidate_port
                return
        raise Exception("No available ports")
Example #12
0
class BaseCommsClient(object):
    """Base class for client-side suite object interfaces.

    Builds a URL for a named function on a suite server and sends the
    request with "requests" when a recent enough version can be imported,
    otherwise with "urllib2". Requests use HTTP Digest authentication.
    """

    ACCESS_DESCRIPTION = 'private'
    METHOD = 'POST'  # HTTP method used when the caller does not give one.
    METHOD_POST = 'POST'
    METHOD_GET = 'GET'

    def __init__(self, suite, owner=USER, host=None, port=None, timeout=None,
                 my_uuid=None, print_uuid=False):
        """Store connection settings; no network access happens here.

        suite, owner, host, port -- identify the suite server; host and
            port may be None, in which case they are loaded from the suite
            contact file on first use.
        timeout -- request timeout, converted to float unless None.
        my_uuid -- client identifier; a new uuid4 is generated if not given.
        print_uuid -- if true, write the client identifier to stderr.
        """
        self.suite = suite
        self.owner = owner
        self.host = host
        self.port = port
        if timeout is not None:
            timeout = float(timeout)
        self.timeout = timeout
        self.my_uuid = my_uuid or uuid4()
        if print_uuid:
            print >> sys.stderr, '%s' % self.my_uuid
        self.srv_files_mgr = SuiteSrvFilesManager()
        self.prog_name = os.path.basename(sys.argv[0])
        self.server_cert = None  # Cached by _get_verify().
        self.auth = None  # Cached by _get_auth().

    def call_server_func(self, category, fname, **fargs):
        """Call server_object.fname(**fargs) and return the server's reply.

        The keyword arguments "payload" (request body) and "method" (HTTP
        method, default self.METHOD) are consumed here. Any other keyword
        arguments are sent as URL query parameters.

        Raise ConnectionInfoError if the suite contact information cannot
        be loaded.
        """
        if self.host is None and self.port is not None:
            self.host = get_hostname()
        if self.host is None or self.port is None:
            try:
                self._load_contact_info()
            except (IOError, TypeError, ValueError, SuiteServiceFileError):
                # TypeError covers int(None) when the contact data has no
                # port.
                raise ConnectionInfoError(self.suite)
        handle_proxies()
        payload = fargs.pop("payload", None)
        method = fargs.pop("method", self.METHOD)
        host = self.host
        # Drop the domain part of the host name, unless the first dotted
        # component is all digits (presumably an IP address).
        if not self.host.split(".")[0].isdigit():
            host = self.host.split(".")[0]
        if host == "localhost":
            host = get_hostname().split(".")[0]
        url = 'https://%s:%s/%s/%s' % (host, self.port, category, fname)
        if fargs:
            import urllib
            params = urllib.urlencode(fargs, doseq=True)
            url += "?" + params
        return self._get_data_from_url(url, payload, method=method)

    def _get_data_from_url(self, url, json_data, method=None):
        """Send the request with requests if usable, otherwise urllib2.

        requests is used only if it can be imported and its version is at
        least 2.4.2.
        """
        requests_ok = True
        try:
            import requests
        except ImportError:
            requests_ok = False
        else:
            import re
            # Use only the leading digits of each component, so a version
            # such as "2.4.2.dev0" cannot raise ValueError.
            version = []
            for item in requests.__version__.split(".")[:3]:
                match = re.match(r"\d+", item)
                if match is None:
                    break
                version.append(int(match.group()))
            if version < [2, 4, 2]:
                requests_ok = False
        if requests_ok:
            return self._get_data_from_url_with_requests(
                url, json_data, method=method)
        return self._get_data_from_url_with_urllib2(
            url, json_data, method=method)

    def _get_data_from_url_with_requests(self, url, json_data, method=None):
        """Send the request using the requests library.

        Return the decoded JSON reply, or the reply text if it is not JSON.

        Raise ConnectionTimeout on a requests timeout, ConnectionDeniedError
        on a 401 reply, and ConnectionError on any other request failure or
        HTTP error status.
        """
        import requests
        username, password = self._get_auth()
        auth = requests.auth.HTTPDigestAuth(username, password)
        if not hasattr(self, "session"):
            # One session per client, created on first use.
            self.session = requests.Session()
        if method is None:
            method = self.METHOD
        if method == self.METHOD_POST:
            session_method = self.session.post
        else:
            session_method = self.session.get
        try:
            ret = session_method(
                url,
                json=json_data,
                verify=self._get_verify(),
                proxies={},
                headers=self._get_headers(),
                auth=auth,
                timeout=self.timeout
            )
        except requests.exceptions.SSLError as exc:
            if "unknown protocol" in str(exc) and url.startswith("https:"):
                # Server is using http rather than https, for some reason.
                sys.stderr.write(WARNING_NO_HTTPS_SUPPORT.format(exc))
                # Retry with the same HTTP method as the failed attempt.
                return self._get_data_from_url_with_requests(
                    url.replace("https:", "http:", 1), json_data,
                    method=method)
            if cylc.flags.debug:
                import traceback
                traceback.print_exc()
            raise ConnectionError(url, exc)
        except requests.exceptions.Timeout as exc:
            if cylc.flags.debug:
                import traceback
                traceback.print_exc()
            raise ConnectionTimeout(url, exc)
        except requests.exceptions.RequestException as exc:
            if cylc.flags.debug:
                import traceback
                traceback.print_exc()
            raise ConnectionError(url, exc)
        if ret.status_code == 401:
            raise ConnectionDeniedError(url, self.prog_name,
                                        self.ACCESS_DESCRIPTION)
        if ret.status_code >= 400:
            # Show the server-side error before raising below.
            from cylc.network.https.util import get_exception_from_html
            exception_text = get_exception_from_html(ret.text)
            if exception_text:
                sys.stderr.write(exception_text)
            else:
                sys.stderr.write(ret.text)
        try:
            ret.raise_for_status()
        except requests.exceptions.HTTPError as exc:
            if cylc.flags.debug:
                import traceback
                traceback.print_exc()
            raise ConnectionError(url, exc)
        if self.auth and self.auth[1] != NO_PASSPHRASE:
            # The request succeeded with a real passphrase: cache it.
            self.srv_files_mgr.cache_passphrase(
                self.suite, self.owner, self.host, self.auth[1])
        try:
            return ret.json()
        except ValueError:
            return ret.text

    def _get_data_from_url_with_urllib2(self, url, json_data, method=None):
        """Send the request using urllib2.

        Return the decoded JSON reply, or the reply text if it is not JSON.

        Raise ConnectionTimeout if the URLError text contains "timed out",
        and ConnectionError on any other failure to open the URL.
        """
        import json
        import urllib2
        import ssl
        # SECURITY NOTE(review): this replaces the process-wide default HTTPS
        # context with an unverified one, so server certificates are not
        # checked on this code path. Confirm that this is intended.
        if hasattr(ssl, '_create_unverified_context'):
            ssl._create_default_https_context = ssl._create_unverified_context
        if method is None:
            method = self.METHOD
        orig_json_data = json_data
        username, password = self._get_auth()
        auth_manager = urllib2.HTTPPasswordMgrWithDefaultRealm()
        auth_manager.add_password(None, url, username, password)
        auth = urllib2.HTTPDigestAuthHandler(auth_manager)
        opener = urllib2.build_opener(auth, urllib2.HTTPSHandler())
        # list() so that append works whatever items() returns.
        headers_list = list(self._get_headers().items())
        if json_data:
            json_data = json.dumps(json_data)
            headers_list.append(('Accept', 'application/json'))
            json_headers = {'Content-Type': 'application/json',
                            'Content-Length': len(json_data)}
        else:
            json_data = None
            json_headers = {'Content-Length': 0}
        opener.addheaders = headers_list
        req = urllib2.Request(url, json_data, json_headers)

        # This is an unpleasant monkey patch, but there isn't an alternative.
        # urllib2 uses POST iff there is a data payload, but that is not the
        # correct criterion. The difference is basically that POST changes
        # server state and GET doesn't.
        req.get_method = lambda: method
        try:
            response = opener.open(req, timeout=self.timeout)
        except urllib2.URLError as exc:
            if "unknown protocol" in str(exc) and url.startswith("https:"):
                # Server is using http rather than https, for some reason.
                sys.stderr.write(WARNING_NO_HTTPS_SUPPORT.format(exc))
                # Retry with the same HTTP method as the failed attempt.
                return self._get_data_from_url_with_urllib2(
                    url.replace("https:", "http:", 1), orig_json_data,
                    method=method)
            if cylc.flags.debug:
                import traceback
                traceback.print_exc()
            if "timed out" in str(exc):
                raise ConnectionTimeout(url, exc)
            else:
                raise ConnectionError(url, exc)
        except Exception as exc:
            if cylc.flags.debug:
                import traceback
                traceback.print_exc()
            raise ConnectionError(url, exc)

        # NOTE(review): urllib2 openers normally raise HTTPError (a URLError)
        # for 4xx/5xx replies, so the status checks below may never trigger;
        # confirm before relying on ConnectionDeniedError from this path.
        if response.getcode() == 401:
            raise ConnectionDeniedError(url, self.prog_name,
                                        self.ACCESS_DESCRIPTION)
        response_text = response.read()
        if response.getcode() >= 400:
            from cylc.network.https.util import get_exception_from_html
            exception_text = get_exception_from_html(response_text)
            if exception_text:
                sys.stderr.write(exception_text)
            else:
                sys.stderr.write(response_text)
            raise ConnectionError(url,
                                  "%s HTTP return code" % response.getcode())
        if self.auth and self.auth[1] != NO_PASSPHRASE:
            # The request succeeded with a real passphrase: cache it.
            self.srv_files_mgr.cache_passphrase(
                self.suite, self.owner, self.host, self.auth[1])
        try:
            return json.loads(response_text)
        except ValueError:
            return response_text

    def _get_auth(self):
        """Return a (user, password) pair for Digest Auth.

        The pair is ('cylc', passphrase) if a suite passphrase other than
        NO_PASSPHRASE can be loaded, else ('anon', NO_PASSPHRASE). The
        result is cached in self.auth.
        """
        if self.auth is None:
            self.auth = ('anon', NO_PASSPHRASE)
            try:
                pphrase = self.srv_files_mgr.get_auth_item(
                    self.srv_files_mgr.FILE_BASE_PASSPHRASE,
                    self.suite, self.owner, self.host, content=True)
            except SuiteServiceFileError:
                pass
            else:
                if pphrase and pphrase != NO_PASSPHRASE:
                    self.auth = ('cylc', pphrase)
        return self.auth

    def _get_headers(self):
        """Return HTTP headers identifying the client."""
        user_agent_string = (
            "cylc/%s prog_name/%s uuid/%s" % (
                CYLC_VERSION, self.prog_name, self.my_uuid
            )
        )
        auth_info = "%s@%s" % (USER, get_hostname())
        return {"User-Agent": user_agent_string,
                "From": auth_info}

    def _get_verify(self):
        """Return the server certificate if possible.

        Return False if the certificate item cannot be obtained. The result
        is cached in self.server_cert and passed to requests as "verify".
        """
        if self.server_cert is None:
            try:
                self.server_cert = self.srv_files_mgr.get_auth_item(
                    self.srv_files_mgr.FILE_BASE_SSL_CERT,
                    self.suite, self.owner, self.host)
            except SuiteServiceFileError:
                self.server_cert = False
        return self.server_cert

    def _load_contact_info(self):
        """Obtain suite owner, host, port info.

        Determine host and port using content in port file, unless already
        specified. Values already set on this object are not overwritten.
        """
        if self.host and self.port:
            return
        data = self.srv_files_mgr.load_contact_file(
            self.suite, self.owner, self.host)
        if not self.host:
            self.host = data.get(self.srv_files_mgr.KEY_HOST)
        if not self.port:
            self.port = int(data.get(self.srv_files_mgr.KEY_PORT))
        if not self.owner:
            self.owner = data.get(self.srv_files_mgr.KEY_OWNER)

    def reset(self, *args, **kwargs):
        """Do nothing; accepts and ignores any arguments."""
        pass

    def signout(self, *args, **kwargs):
        """Do nothing; accepts and ignores any arguments."""
        pass
Example #13
0
class HTTPServer(object):
    """HTTP(S) server by cherrypy, for serving suite runtime API.

    Constructing an instance reads the communication settings from the global
    configuration, sets up HTTP Digest authentication from the suite
    passphrase, and immediately starts the cherrypy engine on the first
    usable port from the configured range.
    """

    API = 2
    LOG_CONNECT_DENIED_TMPL = "[client-connect] DENIED %s@%s:%s %s"
    RE_MESSAGE_TIME = re.compile(
        r'\A(.+) at (' + RE_DATE_TIME_FORMAT_EXTENDED + r')\Z', re.DOTALL)
    # Mount points used by connect(); "/id" is for back-compat with "scan".
    _MOUNT_POINTS = ('/', '/id')

    def __init__(self, suite):
        """Configure authentication and SSL for *suite*, then start serving.

        Raises CylcError if the communication method is not "http" and the
        suite's SSL certificate or private key cannot be obtained.
        """
        # Suite only needed for back-compat with old clients (see below):
        self.suite = suite
        self.engine = None
        self.port = None

        # Figure out the ports we are allowed to use.
        base_port = glbl_cfg().get(['communication', 'base port'])
        max_ports = glbl_cfg().get(
            ['communication', 'maximum number of ports'])
        # Materialise as a list: random.shuffle needs a mutable sequence.
        self.ok_ports = list(
            range(int(base_port), int(base_port) + int(max_ports)))
        random.shuffle(self.ok_ports)

        comms_options = glbl_cfg().get(['communication', 'options'])

        # HTTP Digest Auth uses MD5 - pretty secure in this use case.
        # Extending it with extra algorithms is allowed, but won't be
        # supported by most browsers. requests and urllib2 are OK though.
        self.hash_algorithm = "MD5"
        if "SHA1" in comms_options:
            # Note 'SHA' rather than 'SHA1'.
            self.hash_algorithm = "SHA"

        self.srv_files_mgr = SuiteSrvFilesManager()
        self.comms_method = glbl_cfg().get(['communication', 'method'])
        # Two digest users: "cylc" with the suite passphrase, and "anon".
        self.get_ha1 = cherrypy.lib.auth_digest.get_ha1_dict_plain(
            {
                'cylc': self.srv_files_mgr.get_auth_item(
                    self.srv_files_mgr.FILE_BASE_PASSPHRASE,
                    suite, content=True),
                'anon': NO_PASSPHRASE
            },
            algorithm=self.hash_algorithm)
        if self.comms_method == 'http':
            self.cert = None
            self.pkey = None
        else:  # if self.comms_method in [None, 'https']:
            try:
                self.cert = self.srv_files_mgr.get_auth_item(
                    self.srv_files_mgr.FILE_BASE_SSL_CERT, suite)
                self.pkey = self.srv_files_mgr.get_auth_item(
                    self.srv_files_mgr.FILE_BASE_SSL_PEM, suite)
            except SuiteServiceFileError:
                ERR.error("no HTTPS/OpenSSL support. Aborting...")
                raise CylcError("No HTTPS support. "
                                "Configure user's global.rc to use HTTP.")
        self.start()

    @cherrypy.expose
    def apiversion(self):
        """Return API version."""
        return str(self.API)

    @staticmethod
    def connect(schd):
        """Mount suite scheduler object to the web server.

        A separate SuiteRuntimeService is mounted at each mount point.
        """
        for mount_point in HTTPServer._MOUNT_POINTS:
            cherrypy.tree.mount(SuiteRuntimeService(schd), mount_point)

    @staticmethod
    def disconnect(schd):
        """Disconnect obj from the web server.

        Removes the applications mounted by connect(), plus the legacy
        "/<owner>/<suite>" mount point if it is present. Mount points that
        are not currently registered are ignored.
        """
        legacy_mount_point = '/%s/%s' % (schd.owner, schd.suite)
        for mount_point in HTTPServer._MOUNT_POINTS + (legacy_mount_point,):
            # CherryPy registers apps under the script name with any
            # trailing slash stripped (so "/" is stored as ""); remove
            # whichever spelling is present.
            for key in (mount_point, mount_point.rstrip('/')):
                cherrypy.tree.apps.pop(key, None)

    def get_port(self):
        """Return the web server port (None if no port has been bound)."""
        return self.port

    def shutdown(self):
        """Shutdown the web server, if its engine was ever set up."""
        engine = getattr(self, "engine", None)
        if engine is not None:
            engine.exit()
            engine.block()

    def start(self):
        """Start quick web service.

        Tries each allowed port in (shuffled) order until the cherrypy engine
        reaches the STARTED state, then records that port in self.port.

        Raises Exception("No available ports") if no port could be used.
        """
        # cherrypy.config["tools.encode.on"] = True
        # cherrypy.config["tools.encode.encoding"] = "utf-8"
        cherrypy.config["server.socket_host"] = get_host()
        cherrypy.config["engine.autoreload.on"] = False

        # NOTE(review): __init__ also loads the certificate when the method
        # is None, but SSL is only enabled here for an explicit "https" -
        # confirm that is intended.
        if self.comms_method == "https":
            # Setup SSL etc. Otherwise fail and exit.
            # Require connection method to be the same e.g HTTP/HTTPS matching.
            cherrypy.config['server.ssl_module'] = 'pyopenSSL'
            cherrypy.config['server.ssl_certificate'] = self.cert
            cherrypy.config['server.ssl_private_key'] = self.pkey

        cherrypy.config['log.screen'] = None
        # Fresh random key for the digest authentication tool.
        key = binascii.hexlify(os.urandom(16))
        cherrypy.config.update({
            'tools.auth_digest.on': True,
            'tools.auth_digest.realm': self.suite,
            'tools.auth_digest.get_ha1': self.get_ha1,
            'tools.auth_digest.key': key,
            'tools.auth_digest.algorithm': self.hash_algorithm
        })
        # Log denied connections at the end of every request.
        cherrypy.tools.connect_log = cherrypy.Tool(
            'on_end_resource', self._report_connection_if_denied)
        cherrypy.config['tools.connect_log.on'] = True
        self.engine = cherrypy.engine
        for port in self.ok_ports:
            cherrypy.config["server.socket_port"] = port
            try:
                cherrypy.engine.start()
                cherrypy.engine.wait(cherrypy.engine.states.STARTED)
            except cherrypy.process.wspbus.ChannelFailures:
                if cylc.flags.debug:
                    traceback.print_exc()
                # We need to reinitialise the httpserver for each port attempt.
                cherrypy.server.httpserver = None
            else:
                if cherrypy.engine.state == cherrypy.engine.states.STARTED:
                    self.port = port
                    return
        raise Exception("No available ports")

    @staticmethod
    def _get_client_connection_denied():
        """Return whether the current request was denied (status 401/403)."""
        if "Authorization" not in cherrypy.request.headers:
            # Probably just the initial HTTPS handshake.
            return False
        status = cherrypy.response.status
        if isinstance(status, basestring):
            # String form: compare the leading token; an empty status string
            # counts as "not denied".
            fields = status.split()
            return bool(fields) and fields[0] in ["401", "403"]
        return status in [401, 403]

    def _report_connection_if_denied(self):
        """Log a warning naming the client if its connection was denied."""
        prog_name, user, host, uuid = _get_client_info()[1:]
        connection_denied = self._get_client_connection_denied()
        if connection_denied:
            LOG.warning(self.__class__.LOG_CONNECT_DENIED_TMPL % (
                user, host, prog_name, uuid))
Example #14
0
class RemoteJobHostManager(object):
    """Manage a remote job host.

    Tracks which ``user@host`` targets have had their suite run directory
    initialised, and removes the installed suite contact file from those
    targets on request.
    """

    _INSTANCE = None

    @classmethod
    def get_inst(cls):
        """Return a singleton instance of this class."""
        if cls._INSTANCE is None:
            cls._INSTANCE = cls()
        return cls._INSTANCE

    def __init__(self):
        # Maps user_at_host -> should_unlink. The flag is True when contact
        # and authentication files were copied to the target, so the contact
        # file must be removed again by unlink_suite_contact_files().
        self.initialised_hosts = {}
        self.single_task_mode = False
        self.suite_srv_files_mgr = SuiteSrvFilesManager()

    @staticmethod
    def _split_user_at_host(user_at_host):
        """Return (owner, host) from "[owner@]host"; owner is None if absent."""
        if "@" in user_at_host:
            owner, host = user_at_host.split("@", 1)
            return owner, host
        return None, user_at_host

    @staticmethod
    def _warn_tidy_failure(user_at_host, cmd, proc, out, err):
        """Log a warning for a tidy-up command that failed or was terminated."""
        ERR.warning(
            RemoteJobHostInitError(
                RemoteJobHostInitError.MSG_TIDY,
                user_at_host,
                " ".join([quote(item) for item in cmd]),
                proc.returncode,
                out,
                err,
            )
        )

    def init_suite_run_dir(self, reg, user_at_host):
        """Initialise suite run dir on a user@host.

        Create SUITE_RUN_DIR/log/job/ if necessary.
        Install suite contact environment file.
        Install suite python modules.

        Nothing is done in single task mode, for a target that has already
        been initialised, or for the current user on localhost.

        Raise RemoteJobHostInitError if initialisation cannot complete.

        """
        # initialised_hosts is keyed by the full user_at_host string, so the
        # lookup must use that string too (not the bare host name).
        if self.single_task_mode or user_at_host in self.initialised_hosts:
            return
        owner, host = self._split_user_at_host(user_at_host)
        if (owner, host) in [(None, "localhost"), (USER, "localhost")]:
            return

        r_suite_run_dir = GLOBAL_CFG.get_derived_host_item(
            reg, "suite run directory", host, owner)
        r_log_job_dir = GLOBAL_CFG.get_derived_host_item(
            reg, "suite job log directory", host, owner)
        r_suite_srv_dir = os.path.join(
            r_suite_run_dir, self.suite_srv_files_mgr.DIR_BASE_SRV)

        # Create a UUID file in the service directory.
        # If remote host has the file in its service directory, we can assume
        # that the remote host has a shared file system with the suite host.
        ssh_tmpl = GLOBAL_CFG.get_host_item(
            "remote shell template", host, owner)
        uuid_str = str(uuid4())
        uuid_fname = os.path.join(
            self.suite_srv_files_mgr.get_suite_srv_dir(reg), uuid_str)
        try:
            open(uuid_fname, "wb").close()
            proc = Popen(
                shlex.split(ssh_tmpl) + [
                    "-n", user_at_host,
                    "test", "-e", os.path.join(r_suite_srv_dir, uuid_str)],
                stdout=PIPE,
                stderr=PIPE,
            )
            # communicate() drains the pipes while waiting, so the child
            # cannot block on a full stdout/stderr pipe.
            proc.communicate()
            if proc.returncode == 0:
                # Initialised, but no need to tidy up
                self.initialised_hosts[user_at_host] = False
                return
        finally:
            try:
                os.unlink(uuid_fname)
            except OSError:
                pass

        cmds = []
        # Command to create suite directory structure on remote host.
        cmds.append(
            shlex.split(ssh_tmpl) + [
                "-n", user_at_host,
                "mkdir", "-p",
                r_suite_run_dir, r_log_job_dir, r_suite_srv_dir]
        )
        # The copy template is needed by both copy commands below, so fetch
        # it unconditionally.
        scp_tmpl = GLOBAL_CFG.get_host_item(
            "remote copy template", host, owner)
        # Command to copy contact and authentication files to remote host.
        # Note: no need to do this if task communication method is "poll".
        should_unlink = GLOBAL_CFG.get_host_item(
            "task communication method", host, owner) != "poll"
        if should_unlink:
            cmds.append(
                shlex.split(scp_tmpl)
                + [
                    "-p",
                    self.suite_srv_files_mgr.get_contact_file(reg),
                    self.suite_srv_files_mgr.get_auth_item(
                        self.suite_srv_files_mgr.FILE_BASE_PASSPHRASE, reg),
                    self.suite_srv_files_mgr.get_auth_item(
                        self.suite_srv_files_mgr.FILE_BASE_SSL_CERT, reg),
                    user_at_host + ":" + r_suite_srv_dir + "/",
                ]
            )
        # Command to copy python library to remote host.
        suite_run_py = os.path.join(
            GLOBAL_CFG.get_derived_host_item(reg, "suite run directory"),
            "python")
        if os.path.isdir(suite_run_py):
            cmds.append(
                shlex.split(scp_tmpl) + [
                    "-pr", suite_run_py,
                    user_at_host + ":" + r_suite_run_dir + "/"]
            )
        # Run commands in sequence.
        for cmd in cmds:
            proc = Popen(cmd, stdout=PIPE, stderr=PIPE)
            out, err = proc.communicate()
            if proc.returncode:
                raise RemoteJobHostInitError(
                    RemoteJobHostInitError.MSG_INIT,
                    user_at_host,
                    " ".join([quote(item) for item in cmd]),
                    proc.returncode,
                    out,
                    err,
                )
        self.initialised_hosts[user_at_host] = should_unlink
        LOG.info("Initialised %s:%s" % (user_at_host, r_suite_run_dir))

    def unlink_suite_contact_files(self, reg):
        """Remove suite contact files from initialised hosts.

        This is called on shutdown, so we don't want anything to hang.
        Terminate any incomplete SSH commands after 10 seconds.

        A warning is logged for every command that exits with a non-zero
        status or has to be terminated.
        """
        from time import sleep

        # Issue all SSH commands in parallel
        procs = {}
        for user_at_host, should_unlink in self.initialised_hosts.items():
            if not should_unlink:
                continue
            owner, host = self._split_user_at_host(user_at_host)
            ssh_tmpl = GLOBAL_CFG.get_host_item(
                "remote shell template", host, owner)
            r_suite_contact_file = os.path.join(
                GLOBAL_CFG.get_derived_host_item(
                    reg, "suite run directory", host, owner),
                SuiteSrvFilesManager.DIR_BASE_SRV,
                SuiteSrvFilesManager.FILE_BASE_CONTACT,
            )
            cmd = shlex.split(ssh_tmpl) + [
                "-n", user_at_host, "rm", "-f", r_suite_contact_file]
            procs[user_at_host] = (cmd, Popen(cmd, stdout=PIPE, stderr=PIPE))
        # Wait for commands to complete for a max of 10 seconds
        deadline = time() + 10.0
        while procs and time() < deadline:
            # Iterate over a snapshot: finished entries are deleted below.
            for user_at_host, (cmd, proc) in list(procs.items()):
                # poll() returns None only while the command is still
                # running; an exit status of 0 means it has finished.
                if proc.poll() is None:
                    continue
                del procs[user_at_host]
                out, err = proc.communicate()
                if proc.returncode:
                    self._warn_tidy_failure(user_at_host, cmd, proc, out, err)
            if procs:
                # Avoid spinning at full CPU while commands are running.
                sleep(0.1)
        # Terminate any remaining commands
        for user_at_host, (cmd, proc) in procs.items():
            try:
                proc.terminate()
            except OSError:
                pass
            out, err = proc.communicate()
            proc.wait()
            self._warn_tidy_failure(user_at_host, cmd, proc, out, err)