class BaseCommsClient(object):
    """Base class for client-side suite object interfaces.

    A client resolves the suite host and port (from constructor arguments,
    the suite environment file of a running task job, or the suite port
    file), then issues HTTP(S) requests to ``/<category>/<fname>`` on the
    suite server using HTTP Digest authentication.

    Class attributes:
        ACCESS_DESCRIPTION: passed to ConnectionDeniedError on a 401 reply.
        METHOD: HTTP method used when a call does not specify one.
        METHOD_POST, METHOD_GET: the HTTP method names understood here.
    """

    ACCESS_DESCRIPTION = 'private'
    METHOD = 'POST'
    METHOD_POST = 'POST'
    METHOD_GET = 'GET'

    def __init__(self, suite, owner=USER, host=None, timeout=None,
                 port=None, db=None, my_uuid=None, print_uuid=False):
        """Store connection settings for the given suite.

        Args:
            suite: suite name.
            owner: suite owner.
            host: suite host, or None to look it up on first call.
            timeout: request timeout; converted to float unless None.
            port: suite port, or None to look it up on first call.
            db: passed through to RegistrationDB.
            my_uuid: client identifier; a new uuid4 is used if not given.
            print_uuid: if true, write the client identifier to stderr.
        """
        self.suite = suite
        self.host = host
        self.owner = owner
        if timeout is not None:
            timeout = float(timeout)
        self.timeout = timeout
        self.port = port
        self.my_uuid = my_uuid or uuid4()
        if print_uuid:
            sys.stderr.write('%s\n' % self.my_uuid)
        self.reg_db = RegistrationDB(db)
        self.prog_name = os.path.basename(sys.argv[0])

    def call_server_func(self, category, fname, **fargs):
        """Request /category/fname from the suite server.

        The keyword arguments "payload" (request body, sent as JSON) and
        "method" (HTTP method, default self.METHOD) are consumed here; any
        remaining keyword arguments are URL-encoded into the query string.

        Return the decoded JSON reply, or the raw reply text if the reply
        is not valid JSON.
        """
        if self.host is None or self.port is None:
            self._load_contact_info()
        handle_proxies()
        payload = fargs.pop("payload", None)
        method = fargs.pop("method", self.METHOD)
        # Use the short host name, unless the first dotted component is
        # numeric (which looks like an IP address rather than a name).
        host = self.host
        if not self.host.split(".")[0].isdigit():
            host = self.host.split(".")[0]
        if host == "localhost":
            host = get_hostname().split(".")[0]
        url = 'https://%s:%s/%s/%s' % (
            host, self.port, category, fname
        )
        if fargs:
            import urllib
            params = urllib.urlencode(fargs, doseq=True)
            url += "?" + params
        return self.get_data_from_url(url, payload, method=method)

    def get_data_from_url(self, url, json_data, method=None):
        """Fetch url with "requests" if usable, else with "urllib2".

        "requests" is used only if it can be imported and its version is
        at least 2.4.2.
        """
        requests_ok = True
        try:
            import requests
        except ImportError:
            requests_ok = False
        else:
            try:
                version = [int(_) for _ in requests.__version__.split(".")]
            except ValueError:
                # A non-numeric version component (e.g. a pre-release tag)
                # means we cannot establish that the library is new
                # enough, so use the urllib2 code path instead of failing.
                requests_ok = False
            else:
                # NOTE(review): 2.4.2 is presumably required for the
                # "json" keyword used below - confirm.
                if version < [2, 4, 2]:
                    requests_ok = False
        if requests_ok:
            return self.get_data_from_url_with_requests(
                url, json_data, method=method)
        return self.get_data_from_url_with_urllib2(
            url, json_data, method=method)

    def get_data_from_url_with_requests(self, url, json_data, method=None):
        """Fetch url using the "requests" library.

        Args:
            url: full URL to request.
            json_data: object to send as the JSON request body, or None.
            method: self.METHOD_POST for a POST; anything else is sent as
                a GET. Defaults to self.METHOD.

        Return the decoded JSON reply, or the reply text if it is not
        valid JSON.

        Raise:
            ConnectionDeniedError: on a 401 reply.
            ConnectionTimeout: if the request timed out.
            ConnectionError: on any other request failure or error status.
        """
        import requests
        username, password = self._get_auth()
        auth = requests.auth.HTTPDigestAuth(username, password)
        # Reuse one session for all requests made by this client.
        if not hasattr(self, "session"):
            self.session = requests.Session()
        if method is None:
            method = self.METHOD
        if method == self.METHOD_POST:
            session_method = self.session.post
        else:
            session_method = self.session.get
        try:
            ret = session_method(
                url,
                json=json_data,
                verify=self._get_verify(),
                proxies={},
                headers=self._get_headers(),
                auth=auth,
                timeout=self.timeout
            )
        except requests.exceptions.SSLError as exc:
            if "unknown protocol" in str(exc) and url.startswith("https:"):
                # Server is using http rather than https, for some reason.
                sys.stderr.write(WARNING_NO_HTTPS_SUPPORT.format(exc))
                # Retry over http with the same HTTP method; without it
                # the retry would revert to the default method.
                return self.get_data_from_url_with_requests(
                    url.replace("https:", "http:", 1), json_data,
                    method=method)
            if cylc.flags.debug:
                import traceback
                traceback.print_exc()
            raise ConnectionError(url, exc)
        except requests.exceptions.Timeout as exc:
            if cylc.flags.debug:
                import traceback
                traceback.print_exc()
            raise ConnectionTimeout(url, exc)
        except requests.exceptions.RequestException as exc:
            if cylc.flags.debug:
                import traceback
                traceback.print_exc()
            raise ConnectionError(url, exc)
        if ret.status_code == 401:
            raise ConnectionDeniedError(url, self.prog_name,
                                        self.ACCESS_DESCRIPTION)
        if ret.status_code >= 400:
            # Report the server-side error text before raising below.
            from cylc.network.https.util import get_exception_from_html
            exception_text = get_exception_from_html(ret.text)
            if exception_text:
                sys.stderr.write(exception_text)
            else:
                sys.stderr.write(ret.text)
        try:
            ret.raise_for_status()
        except requests.exceptions.HTTPError as exc:
            if cylc.flags.debug:
                import traceback
                traceback.print_exc()
            raise ConnectionError(url, exc)
        try:
            return ret.json()
        except ValueError:
            return ret.text

    def get_data_from_url_with_urllib2(self, url, json_data, method=None):
        """Fetch url using "urllib2".

        Args:
            url: full URL to request.
            json_data: object to send as the JSON request body, or None.
            method: HTTP method name. Defaults to self.METHOD.

        Return the decoded JSON reply, or the reply text if it is not
        valid JSON.

        Raise:
            ConnectionDeniedError: on a 401 reply.
            ConnectionTimeout: on a URLError that mentions "timed out".
            ConnectionError: on any other failure or error status.
        """
        import json
        import urllib2
        import ssl
        # NOTE(review): this replaces the ssl module's default HTTPS
        # context factory for the whole process with the unverified one,
        # so server certificates are not checked on this code path.
        if hasattr(ssl, '_create_unverified_context'):
            ssl._create_default_https_context = ssl._create_unverified_context
        if method is None:
            method = self.METHOD
        # Keep the undumped payload for a possible http retry below.
        orig_json_data = json_data
        username, password = self._get_auth()
        auth_manager = urllib2.HTTPPasswordMgrWithDefaultRealm()
        auth_manager.add_password(None, url, username, password)
        auth = urllib2.HTTPDigestAuthHandler(auth_manager)
        opener = urllib2.build_opener(auth, urllib2.HTTPSHandler())
        headers_list = list(self._get_headers().items())
        if json_data:
            json_data = json.dumps(json_data)
            headers_list.append(('Accept', 'application/json'))
            json_headers = {'Content-Type': 'application/json',
                            'Content-Length': len(json_data)}
        else:
            json_data = None
            json_headers = {'Content-Length': 0}
        opener.addheaders = headers_list
        req = urllib2.Request(url, json_data, json_headers)

        # This is an unpleasant monkey patch, but there isn't an alternative.
        # urllib2 uses POST iff there is a data payload, but that is not the
        # correct criterion. The difference is basically that POST changes
        # server state and GET doesn't.
        req.get_method = lambda: method
        try:
            response = opener.open(req, timeout=self.timeout)
        except urllib2.URLError as exc:
            if "unknown protocol" in str(exc) and url.startswith("https:"):
                # Server is using http rather than https, for some reason.
                sys.stderr.write(WARNING_NO_HTTPS_SUPPORT.format(exc))
                # Retry over http with the same HTTP method; without it
                # the retry would revert to the default method.
                return self.get_data_from_url_with_urllib2(
                    url.replace("https:", "http:", 1), orig_json_data,
                    method=method)
            if cylc.flags.debug:
                import traceback
                traceback.print_exc()
            if "timed out" in str(exc):
                raise ConnectionTimeout(url, exc)
            else:
                raise ConnectionError(url, exc)
        except Exception as exc:
            if cylc.flags.debug:
                import traceback
                traceback.print_exc()
            raise ConnectionError(url, exc)

        if response.getcode() == 401:
            raise ConnectionDeniedError(url, self.prog_name,
                                        self.ACCESS_DESCRIPTION)
        response_text = response.read()
        if response.getcode() >= 400:
            from cylc.network.https.util import get_exception_from_html
            exception_text = get_exception_from_html(response_text)
            if exception_text:
                sys.stderr.write(exception_text)
            else:
                sys.stderr.write(response_text)
            raise ConnectionError(url,
                                  "%s HTTP return code" % response.getcode())
        try:
            return json.loads(response_text)
        except ValueError:
            return response_text

    def _get_auth(self):
        """Return a (user, password) pair for Digest Auth.

        The suite passphrase is loaded from the registration database and,
        if non-empty, cached there. If no passphrase was loaded (None), the
        pair is ('anon', NO_PASSPHRASE); otherwise ('cylc', passphrase).
        """
        self.pphrase = self.reg_db.load_passphrase(
            self.suite, self.owner, self.host)
        if self.pphrase:
            self.reg_db.cache_passphrase(
                self.suite, self.owner, self.host, self.pphrase)
        if self.pphrase is None:
            return 'anon', NO_PASSPHRASE
        return 'cylc', self.pphrase

    def _get_headers(self):
        """Return HTTP headers identifying the client."""
        user_agent_string = (
            "cylc/%s prog_name/%s uuid/%s" % (
                CYLC_VERSION, self.prog_name, self.my_uuid
            )
        )
        auth_info = "%s@%s" % (USER, get_hostname())
        return {"User-Agent": user_agent_string,
                "From": auth_info}

    def _get_verify(self):
        """Return the server certificate if possible.

        The loaded certificate is remembered on the instance. If loading
        raises PassphraseError, return False and remember nothing, so the
        next call tries again.
        """
        if not hasattr(self, "server_cert"):
            try:
                self.server_cert = self.reg_db.load_item(
                    self.suite, self.owner, self.host, "certificate")
            except PassphraseError:
                return False
        return self.server_cert

    def _load_contact_info(self):
        """Obtain URL info.

        Determine host and port using content in port file, unless already
        specified.

        """
        if self.host and self.port:
            return
        if 'CYLC_SUITE_RUN_DIR' in os.environ:
            # Looks like we are in a running task job, so we should be able to
            # use "cylc-suite-env" file under the suite running directory
            try:
                suite_env = CylcSuiteEnv.load(
                    self.suite, os.environ['CYLC_SUITE_RUN_DIR'])
            except CylcSuiteEnvLoadError:
                # Fall through to the port file below.
                if cylc.flags.debug:
                    import traceback
                    traceback.print_exc()
            else:
                self.host = suite_env.suite_host
                self.port = suite_env.suite_port
                self.owner = suite_env.suite_owner
        if self.host is None or self.port is None:
            self._load_port_file()

    def _load_port_file(self):
        """Load port, host, etc from port file.

        The port file is read locally, or via the configured remote shell
        ("cat" on the suite host) if the host or owner is remote. The first
        line is the port; the second line, if present, is the host. Values
        that are already set on this client are not overwritten.

        Raise PortFileError if the port is not already known and the port
        file cannot be read or has bad content.
        """
        # GLOBAL_CFG is expensive to import, so only load on demand
        from cylc.cfgspec.globalcfg import GLOBAL_CFG
        port_file_path = os.path.join(
            GLOBAL_CFG.get(['communication', 'ports directory']), self.suite)
        out = ""
        if is_remote_host(self.host) or is_remote_user(self.owner):
            # Only load these modules on demand, as they may be expensive
            import shlex
            from subprocess import Popen, PIPE
            ssh_tmpl = str(GLOBAL_CFG.get_host_item(
                'remote shell template', self.host, self.owner))
            ssh_tmpl = ssh_tmpl.replace(' %s', '')
            user_at_host = ''
            if self.owner:
                user_at_host = self.owner + '@'
            if self.host:
                user_at_host += self.host
            else:
                user_at_host += 'localhost'
            # Refer to the home directory symbolically in the remote path.
            r_port_file_path = port_file_path.replace(
                os.environ['HOME'], '$HOME')
            command = shlex.split(ssh_tmpl) + [
                user_at_host, 'cat', r_port_file_path]
            proc = Popen(command, stdout=PIPE, stderr=PIPE)
            out, err = proc.communicate()
            ret_code = proc.wait()
            if ret_code:
                if cylc.flags.debug:
                    sys.stderr.write("%s\n" % ({
                        "code": ret_code, "command": command,
                        "stdout": out, "stderr": err},))
                if self.port is None:
                    raise PortFileError(
                        "Port file '%s:%s' not found - suite not running?." %
                        (user_at_host, r_port_file_path))
        else:
            try:
                # Close the file promptly rather than leaking the handle.
                with open(port_file_path) as handle:
                    out = handle.read()
            except IOError:
                if self.port is None:
                    raise PortFileError(
                        "Port file '%s' not found - suite not running?." %
                        (port_file_path))
        lines = out.splitlines()
        if self.port is None:
            try:
                self.port = int(lines[0])
            except (IndexError, ValueError):
                raise PortFileError(
                    "ERROR, bad content in port file: %s" % port_file_path)
        if self.host is None:
            if len(lines) >= 2:
                self.host = lines[1].strip()
            else:
                self.host = get_hostname()

    def reset(self, *args, **kwargs):
        """Do nothing; accepts and ignores any arguments."""
        pass

    def signout(self, *args, **kwargs):
        """Do nothing; accepts and ignores any arguments."""
        pass
class CommsDaemon(object):
    """Wrap HTTPS daemon for a suite."""

    def __init__(self, suite, suite_dir):
        """Load the suite's credentials and start the daemon."""
        # The suite name is only needed for back-compat with old clients.
        self.suite = suite

        # Work out which ports we may use, and try them in random order.
        first_port = GLOBAL_CFG.get(["communication", "base port"])
        port_count = GLOBAL_CFG.get(
            ["communication", "maximum number of ports"])
        first_port = int(first_port)
        self.ok_ports = range(first_port, first_port + int(port_count))
        random.shuffle(self.ok_ports)

        # HTTP Digest Auth uses MD5 - pretty secure in this use case.
        # Extending it with extra algorithms is allowed, but won't be
        # supported by most browsers. requests and urllib2 are OK though.
        # Note 'SHA' rather than 'SHA1' for the algorithm name.
        options = GLOBAL_CFG.get(["communication", "options"])
        self.hash_algorithm = "SHA" if "SHA1" in options else "MD5"

        self.reg_db = RegistrationDB()
        try:
            self.cert = self.reg_db.load_item(
                suite, USER, None, "certificate", create_ok=True)
            self.pkey = self.reg_db.load_item(
                suite, USER, None, "private_key", create_ok=True)
        except PassphraseError:
            # No OpenSSL installed.
            self.cert = None
            self.pkey = None

        # Build the digest lookup, then drop the plain-text credentials.
        secret = self.reg_db.load_passphrase(suite, USER, None)
        credentials = {"cylc": secret, "anon": NO_PASSPHRASE}
        self.get_ha1 = cherrypy.lib.auth_digest.get_ha1_dict_plain(
            credentials, algorithm=self.hash_algorithm)
        del secret
        del credentials

        self.client_reporter = CommsClientReporter.get_inst()
        self.start()

    def start(self):
        """Initialise the web server for this daemon."""
        _ws_init(self)

    def shutdown(self):
        """Shutdown the daemon."""
        if not hasattr(self, "engine"):
            return
        self.engine.exit()
        self.engine.block()

    def connect(self, obj, name):
        """Connect obj and name to the daemon."""
        import cherrypy
        mount_point = "/" + name
        cherrypy.tree.mount(obj, mount_point)

    def disconnect(self, obj):
        """Disconnect obj from the daemon."""
        return None

    def get_port(self):
        """Return the daemon port."""
        return self.port

    def report_connection_if_denied(self):
        """Ask the client reporter to report the connection if denied."""
        self.client_reporter.report_connection_if_denied()