import zmq
import zmq.auth
from zmq.auth.thread import ThreadAuthenticator

# load_and_set_key() is an external helper (defined elsewhere in this module),
# expected to load a certificate and set curve_secretkey/curve_publickey on the socket.


class Authenticator(object):

    _authenticators = {}

    @classmethod
    def instance(cls, public_keys_dir):
        '''Return the shared instance for this key directory; avoid creating multiple instances.'''
        if public_keys_dir in cls._authenticators:
            return cls._authenticators[public_keys_dir]
        new_instance = cls(public_keys_dir)
        cls._authenticators[public_keys_dir] = new_instance
        return new_instance

    def __init__(self, public_keys_dir):
        self._auth = ThreadAuthenticator(zmq.Context.instance())
        self._auth.start()
        self._auth.allow('*')
        self._auth.configure_curve(domain='*', location=public_keys_dir)

    def set_server_key(self, zmq_socket, server_secret_key_path):
        '''Must be called before bind().'''
        load_and_set_key(zmq_socket, server_secret_key_path)
        zmq_socket.curve_server = True

    def set_client_key(self, zmq_socket, client_secret_key_path,
                       server_public_key_path):
        '''Must be called before connect().'''
        load_and_set_key(zmq_socket, client_secret_key_path)
        server_public, _ = zmq.auth.load_certificate(server_public_key_path)
        zmq_socket.curve_serverkey = server_public

    def stop(self):
        self._auth.stop()
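
# --- Usage sketch for the Authenticator wrapper above (not part of the original
# module): assumes load_and_set_key() behaves as described and that the
# certificate files below already exist. All paths are hypothetical.
import zmq

auth = Authenticator.instance('/etc/myapp/public_keys')

ctx = zmq.Context.instance()
server = ctx.socket(zmq.REP)
auth.set_server_key(server, '/etc/myapp/private_keys/server.key_secret')
server.bind('tcp://*:5556')  # bind only after the key is set

client = ctx.socket(zmq.REQ)
auth.set_client_key(client, '/etc/myapp/private_keys/client.key_secret',
                    '/etc/myapp/public_keys/server.key')
client.connect('tcp://127.0.0.1:5556')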
import socket as socket_m

import zmq
from zmq.auth.thread import ThreadAuthenticator


def main():
    localhost = socket_m.getfqdn()
    port = "5556"
    # ip = "*"
    ip = socket_m.gethostbyaddr(localhost)[2][0]

    context = zmq.Context()
    socket = context.socket(zmq.PULL)
    socket.zap_domain = b'global'
    socket.bind("tcp://" + ip + ":%s" % port)

    auth = ThreadAuthenticator(context)
    host = localhost
    # host = asap3-p00
    whitelist = socket_m.gethostbyaddr(host)[2][0]
    # whitelist = None
    auth.start()
    if whitelist is None:
        auth.auth = None
    else:
        auth.allow(whitelist)

    try:
        while True:
            message = socket.recv_multipart()
            print("received reply ", message)
    except KeyboardInterrupt:
        pass
    finally:
        auth.stop()
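
# --- Matching sender for the receiver above (a sketch; host and port are
# assumptions). With the NULL mechanism plus a ZAP domain, the connecting peer
# needs no credentials: the authenticator filters purely by source IP, so this
# only succeeds when run from a whitelisted address.
import zmq

ctx = zmq.Context()
push = ctx.socket(zmq.PUSH)
push.connect("tcp://receiver-host:5556")  # hypothetical endpoint
push.send_multipart([b"hello"])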
def run_mdp_broker():
    args = docopt("""Usage: mdp-broker [options] <config>

Options:
  -h --help    show this help message and exit
  -s --secure  generate (and print) client & broker keys for a secure server
""")
    global log
    _setup_logging(args['<config>'])
    log = logging.getLogger(__name__)
    cp = ConfigParser()
    cp.read(args['<config>'])
    # Parse settings a bit
    raw = dict((option, cp.get('mdp-broker', option))
               for option in cp.options('mdp-broker'))
    s = SettingsSchema().to_python(raw)
    if args['--secure']:
        broker_key = Key.generate()
        client_key = Key.generate()
        s['key'] = dict(broker=broker_key, client=client_key)
        log.info('Auto-generated keys: %s_%s_%s',
                 broker_key.public, client_key.public, client_key.secret)
        log.info('  broker.public: %s', broker_key.public)
        log.info('  client.public: %s', client_key.public)
        log.info('  client.secret: %s', client_key.secret)
    auth = None
    if s['key']:
        log.info('Starting secure mdp-broker on %s', s['uri'])
        auth = ThreadAuthenticator()
        auth.start()
        auth.thread.authenticator.certs['*'] = {
            s['key']['client'].public: 'OK'}
        broker = SecureMajorDomoBroker(s['key']['broker'], s['uri'])
    else:
        log.info('Starting mdp-broker on %s', s['uri'])
        broker = MajorDomoBroker(s['uri'])
    try:
        broker.serve_forever()
    except:
        if auth is not None:
            auth.stop()
        raise
import sys

import zmq
import zmq.auth
from zmq.auth.thread import ThreadAuthenticator


def main():
    auth = ThreadAuthenticator(zmq.Context.instance())
    auth.start()
    auth.allow('127.0.0.1')
    # Tell the authenticator how to handle CURVE requests
    auth.configure_curve(domain='*', location=zmq.auth.CURVE_ALLOW_ANY)
    key = Key.load('example/broker.key_secret')
    broker = SecureMajorDomoBroker(key, sys.argv[1])
    try:
        broker.serve_forever()
    except KeyboardInterrupt:
        auth.stop()
        raise
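
# --- Client-side counterpart for the secure broker above, sketched with plain
# pyzmq (the MajorDomo client classes are not shown here, so this covers socket
# setup only; certificate paths and the endpoint are assumptions). Because the
# broker uses CURVE_ALLOW_ANY, any client keypair is accepted as long as the
# client knows the broker's public key.
import zmq
import zmq.auth

ctx = zmq.Context.instance()
client = ctx.socket(zmq.DEALER)

# The client's own keypair (any keypair works under CURVE_ALLOW_ANY).
client_public, client_secret = zmq.auth.load_certificate('example/client.key_secret')
client.curve_publickey = client_public
client.curve_secretkey = client_secret

# The broker's *public* key must be known in advance.
server_public, _ = zmq.auth.load_certificate('example/broker.key')
client.curve_serverkey = server_public

client.connect('tcp://127.0.0.1:5555')  # hypothetical broker endpoint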
import socket
import time

import zmq
from zmq.auth.thread import ThreadAuthenticator


def main():
    port = "5556"
    socket_ip = "*"
    # ip = socket.getfqdn()

    context = zmq.Context()
    auth = ThreadAuthenticator(context)
    auth.start()
    whitelist = [socket.getfqdn()]
    for host in whitelist:
        hostname, tmp, ip = socket.gethostbyaddr(host)
        auth.allow(ip[0])

    zmq_socket = context.socket(zmq.PUSH)
    zmq_socket.zap_domain = b'global'
    zmq_socket.bind("tcp://" + socket_ip + ":%s" % port)

    try:
        for i in range(5):
            message = [b"World"]
            print("Send: ", message)
            res = zmq_socket.send_multipart(message, copy=False, track=True)
            if res.done:
                print("res: done")
            else:
                print("res: waiting")
                res.wait()
                print("res: waiting...")
            print("sleeping...")
            if i == 1:
                auth.stop()
                zmq_socket.close(0)

                auth.start()
                # ip = socket.gethostbyaddr(socket.getfqdn())[2]
                # auth.allow(ip[0])
                ip = socket.gethostbyaddr(socket.getfqdn())[2]
                auth.deny(ip[0])

                zmq_socket = context.socket(zmq.PUSH)
                zmq_socket.zap_domain = b'global'
                zmq_socket.bind("tcp://" + socket_ip + ":%s" % port)
            time.sleep(1)
            print("sleeping...done")
            i += 1
    finally:
        auth.stop()
def setup_auth():
    global _auth
    assert _options is not None
    auth = _options.get('auth', None)
    if auth is None:
        return
    base_dir = os.path.abspath(
        os.path.join(os.path.dirname(os.path.abspath(__file__)), '..'))
    try:
        _auth = ThreadAuthenticator(_zctx)
        _auth.start()
        whitelist = auth.get('whitelist', None)
        if whitelist is not None:
            _auth.allow(whitelist)
        public_path = auth.get('public_key_dir', 'public_keys')
        _auth.configure_curve(domain='*',
                              location=getExistsPath(base_dir, public_path))
        private_dir = getExistsPath(base_dir,
                                    auth.get('private_key_dir', 'private_keys'))
        private_key = os.path.join(private_dir,
                                   auth.get('private_key_file', 'server.key_secret'))
        server_public, server_private = zmq.auth.load_certificate(private_key)
        _sock.curve_secretkey = server_private
        _sock.curve_publickey = server_public
        _sock.curve_server = True
    except:
        if _auth is not None:
            _auth.stop()
        _auth = None
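
# --- Shape of the 'auth' option block that setup_auth() expects, reconstructed
# from the get() calls above; the values shown are illustrative defaults.
_options = {
    'auth': {
        'whitelist': '192.168.1.10',             # optional IP allowed by the authenticator
        'public_key_dir': 'public_keys',         # directory of client certificates
        'private_key_dir': 'private_keys',       # directory holding the server key
        'private_key_file': 'server.key_secret',
    },
}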
def run(self): self.set_status("Server Startup") self.set_status("Creating zmq Contexts",1) serverctx = zmq.Context() self.set_status("Starting zmq ThreadedAuthenticator",1) #serverauth = zmq.auth.ThreadedAuthenticator(serverctx) serverauth = ThreadAuthenticator(serverctx) serverauth.start() with taco.globals.settings_lock: bindip = taco.globals.settings["Application IP"] bindport = taco.globals.settings["Application Port"] localuuid = taco.globals.settings["Local UUID"] publicdir = os.path.normpath(os.path.abspath(taco.globals.settings["TacoNET Certificates Store"] + "/" + taco.globals.settings["Local UUID"] + "/public/")) privatedir = os.path.normpath(os.path.abspath(taco.globals.settings["TacoNET Certificates Store"] + "/" + taco.globals.settings["Local UUID"] + "/private/")) self.set_status("Configuring Curve to use publickey dir:" + publicdir) serverauth.configure_curve(domain='*', location=publicdir) #auth.configure_curve(domain='*', location=zmq.auth.CURVE_ALLOW_ANY) self.set_status("Creating Server Context",1) server = serverctx.socket(zmq.REP) server.setsockopt(zmq.LINGER, 0) self.set_status("Loading Server Certs",1) server_public, server_secret = zmq.auth.load_certificate(os.path.normpath(os.path.abspath(privatedir + "/" + taco.constants.KEY_GENERATION_PREFIX +"-server.key_secret"))) server.curve_secretkey = server_secret server.curve_publickey = server_public server.curve_server = True if bindip == "0.0.0.0": bindip ="*" self.set_status("Server is now listening for encrypted ZMQ connections @ "+ "tcp://" + bindip +":" + str(bindport)) server.bind("tcp://" + bindip +":" + str(bindport)) poller = zmq.Poller() poller.register(server, zmq.POLLIN|zmq.POLLOUT) while not self.stop.is_set(): socks = dict(poller.poll(200)) if server in socks and socks[server] == zmq.POLLIN: #self.set_status("Getting a request") data = server.recv() with taco.globals.download_limiter_lock: taco.globals.download_limiter.add(len(data)) (client_uuid,reply) = taco.commands.Proccess_Request(data) if client_uuid!="0": self.set_client_last_request(client_uuid) socks = dict(poller.poll(10)) if server in socks and socks[server] == zmq.POLLOUT: #self.set_status("Replying to a request") with taco.globals.upload_limiter_lock: taco.globals.upload_limiter.add(len(reply)) server.send(reply) self.set_status("Stopping zmq server with 0 second linger") server.close(0) self.set_status("Stopping zmq ThreadedAuthenticator") serverauth.stop() serverctx.term() self.set_status("Server Exit")
class MispZmq:
    message_count = 0
    publish_count = 0
    monitor_thread = None
    auth = None
    socket = None
    pidfile = None

    r: redis.StrictRedis
    namespace: str

    def __init__(self):
        self._logger = logging.getLogger()

        self.tmp_location = Path(__file__).parent.parent / "tmp"
        self.pidfile = self.tmp_location / "mispzmq.pid"
        if self.pidfile.exists():
            with open(self.pidfile.as_posix()) as f:
                pid = f.read()
            if check_pid(pid):
                raise Exception("mispzmq already running on PID {}".format(pid))
            else:
                # Cleanup
                self.pidfile.unlink()
        if (self.tmp_location / "mispzmq_settings.json").exists():
            self._setup()
        else:
            raise Exception("The settings file is missing.")

    def _setup(self):
        with open((self.tmp_location / "mispzmq_settings.json").as_posix()) as settings_file:
            self.settings = json.load(settings_file)
        self.namespace = self.settings["redis_namespace"]
        self.r = redis.StrictRedis(host=self.settings["redis_host"],
                                   db=self.settings["redis_database"],
                                   password=self.settings["redis_password"],
                                   port=self.settings["redis_port"],
                                   decode_responses=True)
        self.timestamp_settings = time.time()
        self._logger.debug("Connected to Redis {}:{}/{}".format(
            self.settings["redis_host"], self.settings["redis_port"],
            self.settings["redis_database"]))

    def _setup_zmq(self):
        context = zmq.Context()

        if "username" in self.settings and self.settings["username"]:
            if "password" not in self.settings or not self.settings["password"]:
                raise Exception("When username is set, password cannot be empty.")

            self.auth = ThreadAuthenticator(context)
            self.auth.start()
            self.auth.configure_plain(domain="*", passwords={
                self.settings["username"]: self.settings["password"]})
        else:
            if self.auth:
                self.auth.stop()
            self.auth = None

        self.socket = context.socket(zmq.PUB)
        if self.settings["username"]:
            self.socket.plain_server = True  # must come before bind
        self.socket.bind("tcp://{}:{}".format(self.settings["host"],
                                              self.settings["port"]))
        self._logger.debug("ZMQ listening on tcp://{}:{}".format(
            self.settings["host"], self.settings["port"]))

        if self._logger.isEnabledFor(logging.DEBUG):
            monitor = self.socket.get_monitor_socket()
            self.monitor_thread = threading.Thread(target=event_monitor,
                                                   args=(monitor, self._logger))
            self.monitor_thread.start()
        else:
            if self.monitor_thread:
                self.socket.disable_monitor()
            self.monitor_thread = None

    def _handle_command(self, command):
        if command == "kill":
            self._logger.info("Kill command received, shutting down.")
            self.clean()
            sys.exit()
        elif command == "reload":
            self._logger.info("Reload command received, reloading settings from file.")
            self._setup()
            self._setup_zmq()
        elif command == "status":
            self._logger.info("Status command received, responding with latest stats.")
            self.r.delete("{}:status".format(self.namespace))
            self.r.lpush("{}:status".format(self.namespace),
                         json.dumps({"timestamp": time.time(),
                                     "timestampSettings": self.timestamp_settings,
                                     "publishCount": self.publish_count,
                                     "messageCount": self.message_count}))
        else:
            self._logger.warning("Received invalid command '{}'.".format(command))

    def _create_pid_file(self):
        with open(self.pidfile.as_posix(), "w") as f:
            f.write(str(os.getpid()))

    def _pub_message(self, topic, data):
        self.socket.send_string("{} {}".format(topic, data))

    def clean(self):
        if self.monitor_thread:
            self.socket.disable_monitor()
        if self.auth:
            self.auth.stop()
        if self.socket:
            self.socket.close()
        if self.pidfile:
            self.pidfile.unlink()

    def main(self):
        self._create_pid_file()
        self._setup_zmq()
        time.sleep(1)

        status_array = [
            "And when you're dead I will be still alive.",
            "And believe me I am still alive.",
            "I'm doing science and I'm still alive.",
            "I feel FANTASTIC and I'm still alive.",
            "While you're dying I'll be still alive.",
        ]
        topics = ["misp_json", "misp_json_event", "misp_json_attribute",
                  "misp_json_sighting", "misp_json_organisation", "misp_json_user",
                  "misp_json_conversation", "misp_json_object",
                  "misp_json_object_reference", "misp_json_audit", "misp_json_tag",
                  "misp_json_warninglist"]
        lists = ["{}:command".format(self.namespace)]
        for topic in topics:
            lists.append("{}:data:{}".format(self.namespace, topic))

        while True:
            data = self.r.blpop(lists, timeout=10)

            if data is None:
                # redis timeout expired
                current_time = int(time.time())
                time_delta = current_time - int(self.timestamp_settings)
                status_entry = int(time_delta / 10 % 5)
                status_message = {
                    "status": status_array[status_entry],
                    "uptime": current_time - int(self.timestamp_settings)
                }
                self._pub_message("misp_json_self", json.dumps(status_message))
                self._logger.debug("No message received for 10 seconds, sending ZMQ status message.")
            else:
                key, value = data
                key = key.replace("{}:".format(self.namespace), "")
                if key == "command":
                    self._handle_command(value)
                elif key.startswith("data:"):
                    topic = key.split(":")[1]
                    self._logger.debug("Received data for topic '{}', sending to ZMQ.".format(topic))
                    self._pub_message(topic, value)
                    self.message_count += 1
                    if topic == "misp_json":
                        self.publish_count += 1
                else:
                    self._logger.warning("Received invalid message '{}'.".format(key))
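
# --- Minimal subscriber for the publisher above, as a sketch (host, port, and
# credentials are assumptions and must match mispzmq_settings.json). With PLAIN
# auth the client only sets a username and password on the socket.
import zmq

ctx = zmq.Context()
sub = ctx.socket(zmq.SUB)
sub.plain_username = b"misp"    # hypothetical credentials
sub.plain_password = b"secret"
sub.setsockopt_string(zmq.SUBSCRIBE, "misp_json")
sub.connect("tcp://127.0.0.1:50000")
print(sub.recv_string())        # "<topic> <json payload>"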
class Actor(object):
    '''The actor class implements all the management and control functions over its components

    :param gModel: the JSON-based dictionary holding the model for the app this actor belongs to.
    :type gModel: dict
    :param gModelName: the name of the top-level model for the app
    :type gModelName: str
    :param aName: name of the actor. It is an index into the gModel that points to the part of the model specific to the actor
    :type aName: str
    :param sysArgv: list of arguments for the actor: -key1 value1 -key2 value2 ...
    :type sysArgv: list
    '''

    def __init__(self, gModel, gModelName, aName, sysArgv):
        '''
        Constructor
        '''
        self.logger = logging.getLogger(__name__)
        self.inst_ = self
        self.appName = gModel["name"]
        self.modelName = gModelName
        self.name = aName
        self.pid = os.getpid()
        self.uuid = None
        self.setupIfaces()
        # Assumption: pid is a 4 byte int
        self.actorID = ipaddress.IPv4Address(self.globalHost).packed + \
            self.pid.to_bytes(4, 'big')
        self.suffix = ""
        if aName not in gModel["actors"]:
            raise BuildError('Actor "%s" unknown' % aName)
        self.model = gModel["actors"][aName]  # Fetch the relevant content from the model

        self.INT_RE = re.compile(r"^[-]?\d+$")
        self.parseParams(sysArgv)

        # Use czmq's context
        czmq_ctx = Zsys.init()
        self.context = zmq.Context.shadow(czmq_ctx.value)
        Zsys.handler_reset()  # Reset previous signal handler

        # Context for app sockets
        self.appContext = zmq.Context()

        if Config.SECURITY:
            (self.public_key, self.private_key) = zmq.auth.load_certificate(const.appCertFile)
            hosts = ['127.0.0.1']
            try:
                with open(const.appDescFile, 'r') as f:
                    content = yaml.load(f)
                    hosts += content.hosts
            except:
                pass
            self.auth = ThreadAuthenticator(self.appContext)
            self.auth.start()
            self.auth.allow(*hosts)
            self.auth.configure_curve(domain='*', location=zmq.auth.CURVE_ALLOW_ANY)
        else:
            (self.public_key, self.private_key) = (None, None)
            self.auth = None
            self.appContext = self.context

        try:
            if os.path.isfile(const.logConfFile) and os.access(const.logConfFile, os.R_OK):
                spdlog_setup.from_file(const.logConfFile)
        except Exception as e:
            self.logger.error("error while configuring componentLogger: %s" % repr(e))

        messages = gModel["messages"]  # Global message types (global on the network)
        self.messageNames = []
        for messageSpec in messages:
            self.messageNames.append(messageSpec["name"])

        locals_ = self.model["locals"]  # Local message types (local to the host)
        self.localNames = []
        for messageSpec in locals_:
            self.localNames.append(messageSpec["type"])

        internals = self.model["internals"]  # Internal message types (internal to the actor process)
        self.internalNames = []
        for messageSpec in internals:
            self.internalNames.append(messageSpec["type"])

        self.components = {}
        instSpecs = self.model["instances"]
        compSpecs = gModel["components"]
        ioSpecs = gModel["devices"]
        for instName in instSpecs:  # Create the component instances: the 'parts'
            instSpec = instSpecs[instName]
            instType = instSpec['type']
            if instType in compSpecs:
                typeSpec = compSpecs[instType]
                ioComp = False
            elif instType in ioSpecs:
                typeSpec = ioSpecs[instType]
                ioComp = True
            else:
                raise BuildError('Component type "%s" for instance "%s" is undefined'
                                 % (instType, instName))
            instFormals = typeSpec['formals']
            instActuals = instSpec['actuals']
            instArgs = self.buildInstArgs(instName, instFormals, instActuals)

            # Check whether the component is a C++ component
            ccComponentFile = 'lib' + instType.lower() + '.so'
            ccComp = os.path.isfile(ccComponentFile)
            try:
                if not ioComp:
                    if ccComp:
                        modObj = importlib.import_module('lib' + instType.lower())
                        self.components[instName] = modObj.create_component_py(
                            self, self.model, typeSpec, instName, instType,
                            instArgs, self.appName, self.name)
                    else:
                        self.components[instName] = Part(self, typeSpec, instName,
                                                         instType, instArgs)
                else:
                    self.components[instName] = Peripheral(self, typeSpec, instName,
                                                           instType, instArgs)
            except Exception as e:
                traceback.print_exc()
                self.logger.error("Error while constructing part '%s.%s': %s"
                                  % (instType, instName, str(e)))

    def getParameterValueType(self, param, defaultType):
        '''Infer the type of a parameter from its value unless a default type is provided.
        In the latter case the parameter's value is converted to that type.

        :param param: a parameter value
        :type param: one of bool,int,float,str
        :param defaultType:
        :type defaultType: one of bool,int,float,str

        :return: a pair (value,type)
        :rtype: tuple
        '''
        paramValue, paramType = None, None
        if defaultType is not None:
            if defaultType == str:
                paramValue, paramType = param, str
            elif defaultType == int:
                paramValue, paramType = int(param), int
            elif defaultType == float:
                paramValue, paramType = float(param), float
            elif defaultType == bool:
                paramType = bool
                paramValue = False if param == "False" else True if param == "True" else None
        else:
            if param == 'True':
                paramValue, paramType = True, bool
            elif param == 'False':
                paramValue, paramType = False, bool
            elif self.INT_RE.match(param) is not None:
                paramValue, paramType = int(param), int
            else:
                try:
                    paramValue, paramType = float(param), float
                except:
                    paramValue, paramType = str(param), str
        return (paramValue, paramType)

    def parseParams(self, sysArgv):
        '''Parse actor arguments from the command line

        Compares the actual arguments to the formal arguments (from the model)
        and fills out the local parameter table accordingly. Generates a warning
        on extra arguments and raises an exception on required but missing ones.
        '''
        self.params = {}
        formals = self.model["formals"]
        optList = []
        for formal in formals:
            key = formal["name"]
            default = None if "default" not in formal else formal["default"]
            self.params[key] = default
            optList.append("%s=" % key)
        try:
            opts, _args = getopt.getopt(sysArgv, '', optList)
        except:
            self.logger.info("Error parsing actor options %s" % str(sysArgv))
            return
        for opt in opts:
            optName2, optValue = opt
            optName = optName2[2:]  # Drop two leading dashes
            if optName in self.params:
                defaultType = None if self.params[optName] is None else type(self.params[optName])
                paramValue, paramType = self.getParameterValueType(optValue, defaultType)
                if self.params[optName] is not None:
                    if paramType != type(self.params[optName]):
                        raise BuildError("Type of default value does not match type of argument %s"
                                         % str((optName, optValue)))
                self.params[optName] = paramValue
            else:
                self.logger.info("Unknown argument %s - ignored" % optName)
        for param in self.params:
            if self.params[param] is None:
                raise BuildError("Required parameter %s missing" % param)

    def buildInstArgs(self, instName, formals, actuals):
        args = {}
        for formal in formals:
            argName = formal['name']
            argValue = None
            actual = next((actual for actual in actuals if actual['name'] == argName), None)
            defaultValue = None
            if 'default' in formal:
                defaultValue = formal['default']
            if actual is not None:
                assert actual['name'] == argName
                if 'param' in actual:
                    paramName = actual['param']
                    if paramName in self.params:
                        argValue = self.params[paramName]
                    else:
                        raise BuildError("Unspecified parameter %s referenced in %s"
                                         % (paramName, instName))
                elif 'value' in actual:
                    argValue = actual['value']
                else:
                    raise BuildError("Actual parameter %s has no value" % argName)
            elif defaultValue is not None:
                argValue = defaultValue
            else:
                raise BuildError("Argument %s in %s has no defined value"
                                 % (argName, instName))
            args[argName] = argValue
        return args

    def isLocalMessage(self, msgTypeName):
        '''Return True if the message type is local
        '''
        return msgTypeName in self.localNames

    def isInnerMessage(self, msgTypeName):
        '''Return True if the message type is internal
        '''
        return msgTypeName in self.internalNames

    def getLocalIface(self):
        '''Return the IP address of the host-local network interface (usually 127.0.0.1)
        '''
        return self.localHost

    def getGlobalIface(self):
        '''Return the IP address of the global network interface
        '''
        return self.globalHost

    def getActorName(self):
        '''Return the name of this actor (as defined in the app model)
        '''
        return self.name

    def getAppName(self):
        '''Return the name of the app this actor belongs to
        '''
        return self.appName

    def getActorID(self):
        '''Return an ID for this actor.

        The actor's id is constructed from the host's IP address and the actor's
        process id. The id is unique for a given host and actor run.
        '''
        return self.actorID

    def setUUID(self, uuid):
        '''Set the UUID for this actor.

        The UUID is dynamically generated (by the peer-to-peer network system)
        and is unique.
        '''
        self.uuid = uuid

    def getUUID(self):
        '''Return the UUID for this actor.
        '''
        return self.uuid

    def setupIfaces(self):
        '''Find the IP addresses of the (host-)local and network(-global) interfaces
        '''
        (globalIPs, globalMACs, _globalNames, localIP) = getNetworkInterfaces()
        try:
            assert len(globalIPs) > 0 and len(globalMACs) > 0
        except:
            self.logger.error("Error: no active network interface")
            raise
        globalIP = globalIPs[0]
        globalMAC = globalMACs[0]
        self.localHost = localIP
        self.globalHost = globalIP
        self.macAddress = globalMAC

    def setup(self):
        '''Perform a setup operation on the actor, after the initial construction
        but before the activation of parts
        '''
        self.logger.info("setup")
        self.suffix = self.macAddress
        self.disco = DiscoClient(self, self.suffix)
        self.disco.start()  # Start the discovery service client
        self.disco.registerApp()  # Register this actor with the discovery service
        self.logger.info("actor registered with disco")
        self.deplc = DeplClient(self, self.suffix)
        self.deplc.start()
        ok = self.deplc.registerApp()
        self.logger.info("actor %s registered with depl" % ("is" if ok else "is not"))
        self.controls = {}
        self.controlMap = {}
        for inst in self.components:
            comp = self.components[inst]
            control = self.context.socket(zmq.PAIR)
            control.bind('inproc://part_' + inst + '_control')
            self.controls[inst] = control
            self.controlMap[id(control)] = comp
            if isinstance(comp, Part):
                self.components[inst].setup(control)
            else:
                self.components[inst].setup()

    def registerEndpoint(self, bundle):
        '''Relay the endpoint registration message to the discovery service client
        '''
        self.logger.info("registerEndpoint")
        result = self.disco.registerEndpoint(bundle)
        for res in result:
            (partName, portName, host, port) = res
            self.updatePart(partName, portName, host, port)

    def registerDevice(self, bundle):
        '''Relay the device registration message to the device interface service client
        '''
        typeName, args = bundle
        msg = (self.appName, self.modelName, typeName, args)
        result = self.deplc.registerDevice(msg)
        return result

    def unregisterDevice(self, bundle):
        '''Relay the device unregistration message to the device interface service client
        '''
        typeName, = bundle
        msg = (self.appName, self.modelName, typeName)
        result = self.deplc.unregisterDevice(msg)
        return result

    def activate(self):
        '''Activate the parts
        '''
        self.logger.info("activate")
        for inst in self.components:
            self.components[inst].activate()

    def deactivate(self):
        '''Deactivate the parts
        '''
        self.logger.info("deactivate")
        for inst in self.components:
            self.components[inst].deactivate()

    def recvChannelMessages(self, channel):
        '''Collect all messages from the channel queue and return them in a list
        '''
        msgs = []
        while True:
            try:
                msg = channel.recv(flags=zmq.NOBLOCK)
                msgs.append(msg)
            except zmq.Again:
                break
        return msgs

    def start(self):
        '''Start and operate the actor (infinite polling loop)
        '''
        self.logger.info("starting")
        self.discoChannel = self.disco.channel  # Private channel to the discovery service
        self.deplChannel = self.deplc.channel

        self.poller = zmq.Poller()  # Set up the poller
        self.poller.register(self.deplChannel, zmq.POLLIN)
        self.poller.register(self.discoChannel, zmq.POLLIN)
        for control in self.controls:
            self.poller.register(self.controls[control], zmq.POLLIN)

        while 1:
            sockets = dict(self.poller.poll())
            if self.discoChannel in sockets:  # If there is a message from a service, handle it
                msgs = self.recvChannelMessages(self.discoChannel)
                for msg in msgs:
                    self.handleServiceUpdate(msg)  # Handle message from disco service
                del sockets[self.discoChannel]
            elif self.deplChannel in sockets:
                msgs = self.recvChannelMessages(self.deplChannel)
                for msg in msgs:
                    self.handleDeplMessage(msg)  # Handle message from depl service
                del sockets[self.deplChannel]
            else:  # Handle messages from the components
                toDelete = []
                for s in sockets:
                    if s in self.controls.values():
                        part = self.controlMap[id(s)]
                        msg = s.recv_pyobj()  # Receive python object from component
                        self.handleEventReport(part, msg)  # Report event
                        toDelete += [s]
                for s in toDelete:
                    del sockets[s]

    def handleServiceUpdate(self, msgBytes):
        '''Handle a service update message from the discovery service
        '''
        msgUpd = disco_capnp.DiscoUpd.from_bytes(msgBytes)  # Parse the incoming message
        which = msgUpd.which()
        if which == 'portUpdate':
            msg = msgUpd.portUpdate
            client = msg.client
            actorHost = client.actorHost
            assert actorHost == self.globalHost  # It has to be addressed to this actor
            actorName = client.actorName
            assert actorName == self.name
            instanceName = client.instanceName
            assert instanceName in self.components  # It has to be for a part of this actor
            portName = client.portName
            scope = msg.scope
            socket = msg.socket
            host = socket.host
            port = socket.port
            if scope == "local":
                assert host == self.localHost
            self.updatePart(instanceName, portName, host, port)  # Update the selected part

    def updatePart(self, instanceName, portName, host, port):
        '''Ask a part to update itself
        '''
        self.logger.info("updatePart %s" % str((instanceName, portName, host, port)))
        part = self.components[instanceName]
        part.handlePortUpdate(portName, host, port)

    def handleDeplMessage(self, msgBytes):
        '''Handle a message from the deployment service
        '''
        msgUpd = deplo_capnp.DeplCmd.from_bytes(msgBytes)  # Parse the incoming message
        which = msgUpd.which()
        if which == 'resourceMsg':
            what = msgUpd.resourceMsg.which()
            if what == 'resCPUX':
                self.handleCPULimit()
            elif what == 'resMemX':
                self.handleMemLimit()
            elif what == 'resSpcX':
                self.handleSpcLimit()
            elif what == 'resNetX':
                self.handleNetLimit()
            else:
                self.logger.error("unknown resource msg from deplo: '%s'" % what)
        elif which == 'reinstateCmd':
            self.handleReinstate()
        elif which == 'nicStateMsg':
            stateMsg = msgUpd.nicStateMsg
            state = str(stateMsg.nicState)
            self.handleNICStateChange(state)
        elif which == 'peerInfoMsg':
            peerMsg = msgUpd.peerInfoMsg
            state = str(peerMsg.peerState)
            uuid = peerMsg.uuid
            self.handlePeerStateChange(state, uuid)
        else:
            self.logger.error("unknown msg from deplo: '%s'" % which)

    def handleReinstate(self):
        self.logger.info('handleReinstate')
        self.poller.unregister(self.discoChannel)
        self.disco.reconnect()
        self.discoChannel = self.disco.channel
        self.poller.register(self.discoChannel, zmq.POLLIN)
        for inst in self.components:
            self.components[inst].handleReinstate()

    def handleNICStateChange(self, state):
        '''Handle the NIC state change message: notify components
        '''
        self.logger.info("handleNICStateChange")
        for component in self.components.values():
            component.handleNICStateChange(state)

    def handlePeerStateChange(self, state, uuid):
        '''Handle the peer state change message: notify components
        '''
        self.logger.info("handlePeerStateChange")
        for component in self.components.values():
            component.handlePeerStateChange(state, uuid)

    def handleCPULimit(self):
        '''Handle the case when the CPU limit is exceeded: notify each component.
        If the component has defined a handler, it will be called.
        '''
        self.logger.info("handleCPULimit")
        for component in self.components.values():
            component.handleCPULimit()

    def handleMemLimit(self):
        '''Handle the case when the memory limit is exceeded: notify each component.
        If the component has defined a handler, it will be called.
        '''
        self.logger.info("handleMemLimit")
        for component in self.components.values():
            component.handleMemLimit()

    def handleSpcLimit(self):
        '''Handle the case when the file space limit is exceeded: notify each component.
        If the component has defined a handler, it will be called.
        '''
        self.logger.info("handleSpcLimit")
        for component in self.components.values():
            component.handleSpcLimit()

    def handleNetLimit(self):
        '''Handle the case when the net usage limit is exceeded: notify each component.
        If the component has defined a handler, it will be called.
        '''
        self.logger.info("handleNetLimit")
        for component in self.components.values():
            component.handleNetLimit()

    def handleEventReport(self, part, msg):
        '''Handle event report from a part

        The event report is forwarded to the deplo service.
        '''
        partName = part.getName()
        typeName = part.getTypeName()
        bundle = (partName, typeName,) + (msg,)
        self.deplc.reportEvent(bundle)

    def terminate(self):
        '''Terminate all functions of the actor.

        Terminate all components, and connections to the deplo/disco services.
        Finally exit the process.
        '''
        self.logger.info("terminating")
        for component in self.components.values():
            component.terminate()
        time.sleep(1.0)
        self.deplc.terminate()
        self.disco.terminate()
        if self.auth:
            self.auth.stop()
        # Clean up everything
        # self.context.destroy()
        # time.sleep(1.0)
        self.logger.info("terminated")
        os._exit(0)
class RpcServer:
    """"""

    def __init__(self) -> None:
        """
        Constructor
        """
        # Save functions dict: key is function name, value is function object
        self._functions: Dict[str, Callable] = {}

        # Zmq port related
        self._context: zmq.Context = zmq.Context()

        # Reply socket (Request-reply pattern)
        self._socket_rep: zmq.Socket = self._context.socket(zmq.REP)

        # Publish socket (Publish-subscribe pattern)
        self._socket_pub: zmq.Socket = self._context.socket(zmq.PUB)

        # Worker thread related
        self._active: bool = False                      # RpcServer status
        self._thread: threading.Thread = None           # RpcServer thread
        self._lock: threading.Lock = threading.Lock()

        # Heartbeat related
        self._heartbeat_at: int = None

        # Authenticator used to ensure data security
        self._authenticator: ThreadAuthenticator = None

    def is_active(self) -> bool:
        """"""
        return self._active

    def start(self,
              rep_address: str,
              pub_address: str,
              username: str = "",
              password: str = "",
              server_secretkey_path: str = "") -> None:
        """
        Start RpcServer
        """
        if self._active:
            return

        # Start authenticator
        if server_secretkey_path:
            self._authenticator = ThreadAuthenticator(self._context)
            self._authenticator.start()
            self._authenticator.configure_curve(
                domain="*", location=zmq.auth.CURVE_ALLOW_ANY)

            publickey, secretkey = zmq.auth.load_certificate(server_secretkey_path)

            self._socket_pub.curve_secretkey = secretkey
            self._socket_pub.curve_publickey = publickey
            self._socket_pub.curve_server = True

            self._socket_rep.curve_secretkey = secretkey
            self._socket_rep.curve_publickey = publickey
            self._socket_rep.curve_server = True
        elif username and password:
            self._authenticator = ThreadAuthenticator(self._context)
            self._authenticator.start()
            self._authenticator.configure_plain(
                domain="*", passwords={username: password})

            self._socket_pub.plain_server = True
            self._socket_rep.plain_server = True

        # Bind socket address
        self._socket_rep.bind(rep_address)
        self._socket_pub.bind(pub_address)

        # Start RpcServer status
        self._active = True

        # Start RpcServer thread
        self._thread = threading.Thread(target=self.run)
        self._thread.start()

        # Init heartbeat publish timestamp
        self._heartbeat_at = time() + HEARTBEAT_INTERVAL

    def stop(self) -> None:
        """
        Stop RpcServer
        """
        if not self._active:
            return

        # Stop RpcServer status
        self._active = False

    def join(self) -> None:
        # Wait for RpcServer thread to exit
        if self._thread and self._thread.is_alive():
            self._thread.join()
        self._thread = None

    def run(self) -> None:
        """
        Run RpcServer functions
        """
        while self._active:
            # Poll response socket for 1 second
            n: int = self._socket_rep.poll(1000)
            self.check_heartbeat()

            if not n:
                continue

            # Receive request data from Reply socket
            req = self._socket_rep.recv_pyobj()

            # Get function name and parameters
            name, args, kwargs = req

            # Try to get and execute callable function object; capture exception information if it fails
            try:
                func: Callable = self._functions[name]
                r: Any = func(*args, **kwargs)
                rep: list = [True, r]
            except Exception as e:  # noqa
                rep: list = [False, traceback.format_exc()]

            # Send callable response by Reply socket
            self._socket_rep.send_pyobj(rep)

        # Unbind socket address
        self._socket_pub.unbind(self._socket_pub.LAST_ENDPOINT)
        self._socket_rep.unbind(self._socket_rep.LAST_ENDPOINT)

        if self._authenticator:
            self._authenticator.stop()

    def publish(self, topic: str, data: Any) -> None:
        """
        Publish data
        """
        with self._lock:
            self._socket_pub.send_pyobj([topic, data])

    def register(self, func: Callable) -> None:
        """
        Register function
        """
        self._functions[func.__name__] = func

    def check_heartbeat(self) -> None:
        """
        Check whether it is required to send heartbeat.
        """
        now: float = time()
        if now >= self._heartbeat_at:
            # Publish heartbeat
            self.publish(HEARTBEAT_TOPIC, now)

            # Update timestamp of next publish
            self._heartbeat_at = now + HEARTBEAT_INTERVAL
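
# --- Usage sketch for RpcServer above (addresses and credentials are
# illustrative; HEARTBEAT_TOPIC and HEARTBEAT_INTERVAL are module constants
# assumed to exist).
def add(a, b):
    return a + b

server = RpcServer()
server.register(add)                        # expose add() to RPC clients
server.start(
    rep_address="tcp://*:2014",             # request/reply endpoint
    pub_address="tcp://*:4102",             # publish endpoint
    username="rpc", password="secret",      # enables PLAIN auth on both sockets
)
server.publish("status", {"ready": True})
server.stop()
server.join()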
def run(self): self.set_status("Client Startup") self.set_status("Creating zmq Contexts",1) clientctx = zmq.Context() self.set_status("Starting zmq ThreadedAuthenticator",1) #clientauth = zmq.auth.ThreadedAuthenticator(clientctx) clientauth = ThreadAuthenticator(clientctx) clientauth.start() with taco.globals.settings_lock: publicdir = os.path.normpath(os.path.abspath(taco.globals.settings["TacoNET Certificates Store"] + "/" + taco.globals.settings["Local UUID"] + "/public/")) privatedir = os.path.normpath(os.path.abspath(taco.globals.settings["TacoNET Certificates Store"] + "/" + taco.globals.settings["Local UUID"] + "/private/")) self.set_status("Configuring Curve to use publickey dir:" + publicdir) clientauth.configure_curve(domain='*', location=publicdir) poller = zmq.Poller() while not self.stop.is_set(): #logging.debug("PRE") result = self.sleep.wait(0.1) #logging.debug(result) self.sleep.clear() if self.stop.is_set(): break if abs(time.time() - self.connect_block_time) > 1: with taco.globals.settings_lock: self.max_upload_rate = taco.globals.settings["Upload Limit"] * taco.constants.KB with taco.globals.settings_lock: self.max_download_rate = taco.globals.settings["Download Limit"] * taco.constants.KB self.chunk_request_rate = float(taco.constants.FILESYSTEM_CHUNK_SIZE) / float(self.max_download_rate) #logging.debug(str((self.max_download_rate,taco.constants.FILESYSTEM_CHUNK_SIZE,self.chunk_request_rate))) self.connect_block_time = time.time() with taco.globals.settings_lock: for peer_uuid in taco.globals.settings["Peers"].keys(): if taco.globals.settings["Peers"][peer_uuid]["enabled"]: #init some defaults if not peer_uuid in self.client_reconnect_mod: self.client_reconnect_mod[peer_uuid] = taco.constants.CLIENT_RECONNECT_MIN if not peer_uuid in self.client_connect_time: self.client_connect_time[peer_uuid] = time.time() + self.client_reconnect_mod[peer_uuid] if not peer_uuid in self.client_timeout: self.client_timeout[peer_uuid] = time.time() + taco.constants.ROLLCALL_TIMEOUT if time.time() >= self.client_connect_time[peer_uuid]: if peer_uuid not in self.clients.keys(): self.set_status("Starting Client for: " + peer_uuid) try: ip_of_client = socket.gethostbyname(taco.globals.settings["Peers"][peer_uuid]["hostname"]) except: self.set_status("Starting of client failed due to bad dns lookup:" + peer_uuid) continue self.clients[peer_uuid] = clientctx.socket(zmq.DEALER) self.clients[peer_uuid].setsockopt(zmq.LINGER, 0) client_public, client_secret = zmq.auth.load_certificate(os.path.normpath(os.path.abspath(privatedir + "/" + taco.constants.KEY_GENERATION_PREFIX +"-client.key_secret"))) self.clients[peer_uuid].curve_secretkey = client_secret self.clients[peer_uuid].curve_publickey = client_public self.clients[peer_uuid].curve_serverkey = str(taco.globals.settings["Peers"][peer_uuid]["serverkey"]) self.clients[peer_uuid].connect("tcp://" + ip_of_client + ":" + str(taco.globals.settings["Peers"][peer_uuid]["port"])) self.next_rollcall[peer_uuid] = time.time() with taco.globals.high_priority_output_queue_lock: taco.globals.high_priority_output_queue[peer_uuid] = Queue.Queue() with taco.globals.medium_priority_output_queue_lock: taco.globals.medium_priority_output_queue[peer_uuid] = Queue.Queue() with taco.globals.low_priority_output_queue_lock: taco.globals.low_priority_output_queue[peer_uuid] = Queue.Queue() with taco.globals.file_request_output_queue_lock: taco.globals.file_request_output_queue[peer_uuid] = Queue.Queue() poller.register(self.clients[peer_uuid],zmq.POLLIN) if 
len(self.clients.keys()) == 0: continue peer_keys = self.clients.keys() random.shuffle(peer_keys) for peer_uuid in peer_keys: #self.set_status("Socket Write Possible:" + peer_uuid) #high priority queue processing with taco.globals.high_priority_output_queue_lock: while not taco.globals.high_priority_output_queue[peer_uuid].empty(): self.set_status("high priority output q not empty:" + peer_uuid) data = taco.globals.high_priority_output_queue[peer_uuid].get() self.clients[peer_uuid].send_multipart(['',data]) self.sleep.set() with taco.globals.upload_limiter_lock: taco.globals.upload_limiter.add(len(data)) #medium priority queue processing with taco.globals.medium_priority_output_queue_lock: while not taco.globals.medium_priority_output_queue[peer_uuid].empty(): self.set_status("medium priority output q not empty:" + peer_uuid) data = taco.globals.medium_priority_output_queue[peer_uuid].get() self.clients[peer_uuid].send_multipart(['',data]) self.sleep.set() with taco.globals.upload_limiter_lock: taco.globals.upload_limiter.add(len(data)) #filereq q, aka the download throttle if time.time() >= self.file_request_time: self.file_request_time = time.time() with taco.globals.file_request_output_queue_lock: if not taco.globals.file_request_output_queue[peer_uuid].empty(): with taco.globals.download_limiter_lock: download_rate = taco.globals.download_limiter.get_rate() bw_percent = download_rate / self.max_download_rate wait_time = self.chunk_request_rate * bw_percent #self.set_status(str((download_rate,self.max_download_rate,self.chunk_request_rate,bw_percent,wait_time))) if wait_time > 0.01: self.file_request_time += wait_time if download_rate < self.max_download_rate: self.set_status("filereq output q not empty+free bw:" + peer_uuid) data = taco.globals.file_request_output_queue[peer_uuid].get() self.clients[peer_uuid].send_multipart(['',data]) self.sleep.set() with taco.globals.upload_limiter_lock: taco.globals.upload_limiter.add(len(data)) #low priority queue processing with taco.globals.low_priority_output_queue_lock: if not taco.globals.low_priority_output_queue[peer_uuid].empty(): with taco.globals.upload_limiter_lock: upload_rate = taco.globals.upload_limiter.get_rate() if upload_rate < self.max_upload_rate: self.set_status("low priority output q not empty+free bw:" + peer_uuid) data = taco.globals.low_priority_output_queue[peer_uuid].get() self.clients[peer_uuid].send_multipart(['',data]) self.sleep.set() with taco.globals.upload_limiter_lock: taco.globals.upload_limiter.add(len(data)) #rollcall special case if self.next_rollcall[peer_uuid] < time.time(): #self.set_status("Requesting Rollcall from: " + peer_uuid) data = taco.commands.Request_Rollcall() self.clients[peer_uuid].send_multipart(['',data]) with taco.globals.upload_limiter_lock: taco.globals.upload_limiter.add(len(data)) self.next_rollcall[peer_uuid] = time.time() + random.randint(taco.constants.ROLLCALL_MIN,taco.constants.ROLLCALL_MAX) self.sleep.set() #continue #RECEIVE BLOCK socks = dict(poller.poll(0)) while self.clients[peer_uuid] in socks and socks[self.clients[peer_uuid]] == zmq.POLLIN: #self.set_status("Socket Read Possible") sink,data = self.clients[peer_uuid].recv_multipart() with taco.globals.download_limiter_lock: taco.globals.download_limiter.add(len(data)) self.set_client_last_reply(peer_uuid) self.next_request = taco.commands.Process_Reply(peer_uuid,data) if self.next_request != "": with taco.globals.medium_priority_output_queue_lock: taco.globals.medium_priority_output_queue[peer_uuid].put(self.next_request) 
self.sleep.set() socks = dict(poller.poll(0)) #cleanup block self.error_msg = [] if self.clients[peer_uuid] in socks and socks[self.clients[peer_uuid]] == zmq.POLLERR: self.error_msg.append("got a socket error") if abs(self.client_timeout[peer_uuid] - time.time()) > taco.constants.ROLLCALL_TIMEOUT: self.error_msg.append("havn't seen communications") if len(self.error_msg) > 0: self.set_status("Stopping client: " + peer_uuid + " -- " + " and ".join(self.error_msg),2) poller.unregister(self.clients[peer_uuid]) self.clients[peer_uuid].close(0) del self.clients[peer_uuid] del self.client_timeout[peer_uuid] with taco.globals.high_priority_output_queue_lock: del taco.globals.high_priority_output_queue[peer_uuid] with taco.globals.medium_priority_output_queue_lock: del taco.globals.medium_priority_output_queue[peer_uuid] with taco.globals.low_priority_output_queue_lock: del taco.globals.low_priority_output_queue[peer_uuid] with taco.globals.file_request_output_queue_lock: del taco.globals.file_request_output_queue[peer_uuid] self.client_reconnect_mod[peer_uuid] = min(self.client_reconnect_mod[peer_uuid] + taco.constants.CLIENT_RECONNECT_MOD,taco.constants.CLIENT_RECONNECT_MAX) self.client_connect_time[peer_uuid] = time.time() + self.client_reconnect_mod[peer_uuid] self.set_status("Terminating Clients") for peer_uuid in self.clients.keys(): self.clients[peer_uuid].close(0) self.set_status("Stopping zmq ThreadedAuthenticator") clientauth.stop() clientctx.term() self.set_status("Clients Exit")
class CombaZMQAdapter(threading.Thread, CombaBase):

    def __init__(self, port):
        self.port = str(port)
        threading.Thread.__init__(self)
        self.shutdown_event = Event()
        self.context = zmq.Context.instance()
        self.authserver = ThreadAuthenticator(self.context)
        self.loadConfig()
        self.start()

    #------------------------------------------------------------------------------------------#
    def run(self):
        """
        run runs on function start
        """
        self.startAuthserver()
        self.data = ''
        self.socket = self.context.socket(zmq.REP)
        self.socket.plain_server = True
        self.socket.bind("tcp://*:" + self.port)
        self.shutdown_event.clear()
        self.controller = CombaController(self, self.lqs_socket, self.lqs_recorder_socket)
        self.controller.messenger.setMailAddresses(self.get('frommail'), self.get('adminmail'))
        self.can_send = False
        # Process tasks forever
        while not self.shutdown_event.is_set():
            self.data = self.socket.recv()
            self.can_send = True
            data = self.data.split(' ')
            command = str(data.pop(0))
            params = "()" if len(data) < 1 else "('" + "','".join(data) + "')"
            try:
                exec "a=self.controller." + command + params
            except SyntaxError:
                self.controller.message('Warning: Syntax Error')
            except AttributeError:
                print "Warning: Method " + command + " does not exist"
                self.controller.message('Warning: Method ' + command + ' does not exist')
            except TypeError:
                print "Warning: Wrong number of params"
                self.controller.message('Warning: Wrong number of params')
            except:
                print "Warning: Unknown Error"
                self.controller.message('Warning: Unknown Error')
        return

    #------------------------------------------------------------------------------------------#
    def halt(self):
        """
        Stop the server
        """
        if self.shutdown_event.is_set():
            return
        try:
            del self.controller
        except:
            pass
        self.shutdown_event.set()
        result = 'failed'
        try:
            result = self.socket.unbind("tcp://*:" + self.port)
        except:
            pass
        # self.socket.close()

    #------------------------------------------------------------------------------------------#
    def reload(self):
        """
        Stop, reload config and start again
        """
        if self.shutdown_event.is_set():
            return
        self.loadConfig()
        self.halt()
        time.sleep(3)
        self.run()

    #------------------------------------------------------------------------------------------#
    def send(self, message):
        """
        Send a message to the client
        :param message: string
        """
        if self.can_send:
            self.socket.send(message)
            self.can_send = False

    #------------------------------------------------------------------------------------------#
    def startAuthserver(self):
        """
        Start zmq authentication server
        """
        # Stop the auth server if it is already running
        if self.authserver.is_alive():
            self.authserver.stop()
        if self.securitylevel > 0:
            # Start the authentication server
            self.authserver.start()
            # At security level 2, also require username and password
            if self.securitylevel > 1:
                try:
                    addresses = CombaWhitelist().getList()
                    for address in addresses:
                        self.authserver.allow(address)
                except:
                    pass
            # Instruct authenticator to handle PLAIN requests
            self.authserver.configure_plain(domain='*', passwords=self.getAccounts())

    #------------------------------------------------------------------------------------------#
    def getAccounts(self):
        """
        Get accounts from redis db
        :return: list - a list of accounts
        """
        accounts = CombaUser().getLogins()
        db = redis.Redis()
        internaccount = db.get('internAccess')
        if not internaccount:
            user = ''.join(random.sample(string.lowercase, 10))
            password = ''.join(random.sample(string.lowercase + string.uppercase + string.digits, 22))
            db.set('internAccess', user + ':' + password)
            intern = [user, password]
        else:
            intern = internaccount.split(':')
        accounts[intern[0]] = intern[1]
        return accounts
class Facilities(furnace_runtime.FurnaceRuntime):
    """
    Main Facilities class.
    """

    def __init__(self, debug, u2, kp, qt, it, ip, fac):
        """
        Constructor, ZMQ connections to the APP partition and backend are built here.

        :param u2: The UUID for the corresponding app instance.
        :param kp: The keypair to use, in the form {'be_key': '[path]', 'app_key': '[path]'}
        :param qt: Custom quotas dict.
        :param ip: Connect to this proxy IP.
        :param it: Connect to this proxy TCP port.
        :param fac: The path to the FAC partition's ZMQ socket.
        """
        super(Facilities, self).__init__(debug=debug)
        self.debug = debug
        self.u2 = u2
        self.fac_path = fac
        self.it = it
        self.ip = ip
        self.already_shutdown = False
        self.tprint(DEBUG, "starting Facilities constructor")
        self.tprint(DEBUG, f"running as UUID={u2}")
        api = api_implementation._default_implementation_type
        if api == "python":
            raise Exception("not running in cpp accelerated mode")
        self.tprint(DEBUG, f"running in api mode: {api}")
        self.fmsg_in = facilities_pb2.FacMessage()
        self.fmsg_out = facilities_pb2.FacMessage()
        self.pmsg_in = facilities_pb2.FacMessage()
        self.pmsg_out = facilities_pb2.FacMessage()
        self.storage = {}
        self.io_id = 0
        self.disk_usage = 0
        self.max_disk_usage = int(qt["_furnace_max_disk"])
        self.tprint(DEBUG, f"MAX_DISK_USAGE {self.max_disk_usage}")
        self.context = zmq.Context(io_threads=4)
        self.poller = zmq.Poller()

        # connections with the proxy
        self.tprint(DEBUG, "loading crypto...")
        self.auth = ThreadAuthenticator(self.context)  # crypto bootstrap
        self.auth.start()
        self.auth.configure_curve(domain="*", location=zmq.auth.CURVE_ALLOW_ANY)
        publisher_public, _ = zmq.auth.load_certificate(kp["be_pub"])
        public, secret = zmq.auth.load_certificate(kp["app_key"])
        assert self.auth.is_alive()
        self.tprint(DEBUG, "loading crypto... DONE")

        TCP_PXY_RTR = f"tcp://{str(ip)}:{int(it)+0}"
        TCP_PXY_PUB = f"tcp://{str(ip)}:{int(it)+1}"

        # connections to the proxy
        s = f"tcp://{self.ip}:{self.it}"
        self.dealer_pxy = self.context.socket(zmq.DEALER)  # async directed
        self.didentity = f"{u2}-d"  # this must match the FE
        self.dealer_pxy.identity = self.didentity.encode()
        self.dealer_pxy.curve_publickey = public
        self.dealer_pxy.curve_secretkey = secret
        self.dealer_pxy.curve_serverkey = publisher_public
        self.tprint(INFO, f"FAC--PXY: Connecting as Dealer to {TCP_PXY_RTR}")
        self.dealer_pxy.connect(s)

        self.req_pxy = self.context.socket(zmq.REQ)  # sync
        self.ridentity = f"{u2}-r"
        self.req_pxy.identity = self.ridentity.encode()
        self.req_pxy.curve_publickey = public
        self.req_pxy.curve_secretkey = secret
        self.req_pxy.curve_serverkey = publisher_public
        self.req_pxy.connect(s)

        self.sub_pxy = self.context.socket(zmq.SUB)  # async broadcast
        self.sub_pxy.curve_publickey = public
        self.sub_pxy.curve_secretkey = secret
        self.sub_pxy.curve_serverkey = publisher_public
        self.sub_pxy.setsockopt(zmq.SUBSCRIBE, b"")
        self.tprint(INFO, f"FAC--PXY: Connecting as Subscriber to {TCP_PXY_PUB}")
        self.sub_pxy.connect(TCP_PXY_PUB)

        self.poller.register(self.dealer_pxy, zmq.POLLIN)
        self.poller.register(self.sub_pxy, zmq.POLLIN)

        # connections with the frontend
        s = f"ipc://{self.fac_path}{os.sep}fac_rtr"
        self.rtr_fac = self.context.socket(zmq.ROUTER)  # fe sync/async msgs
        self.tprint(INFO, f"APP--FAC: Binding as Router to {s}")
        try:
            self.rtr_fac.bind(s)
        except zmq.error.ZMQError:
            self.tprint(ERROR, f"rtr_fac {s} is already in use!")
            sys.exit()

        s = f"ipc://{self.fac_path}{os.sep}fac_pub"
        self.pub_fac = self.context.socket(zmq.PUB)  # bcast to fe
        self.tprint(INFO, f"APP--FAC: Binding as Publisher to {s}")
        try:
            self.pub_fac.bind(s)
        except zmq.error.ZMQError:
            self.tprint(ERROR, f"pub_fac {s} is already in use!")
            sys.exit()

        self.poller.register(self.rtr_fac, zmq.POLLIN)
        self.tprint(DEBUG, "leaving Facilities constructor")

    def loop(self):
        """
        Main event loop. Polls (with timeout) on ZMQ sockets waiting for data
        from the APP partition and backend.

        :returns: Nothing.
        """
        while True:
            socks = dict(self.poller.poll(TIMEOUT_FAC))

            # inbound async from BE (sync should not arrive here)
            if self.dealer_pxy in socks:
                sys.stdout.write("v")
                sync = ASYNC
                raw_msg = self.dealer_pxy.recv()
                self.msgin += 1
                raw_out = self.be2fe_process_protobuf(raw_msg, sync)

            # inbound broadcast from BE
            if self.sub_pxy in socks:
                sys.stdout.write("v")
                raw_msg = self.sub_pxy.recv()
                self.msgin += 1
                self.pmsg_in.ParseFromString(raw_msg)
                self.tprint(DEBUG, "BROADCAST FROM PXY\n%s" % (self.pmsg_in,))
                index = 0
                submsg_list = getattr(self.pmsg_in, "be_msg")
                m = submsg_list[index]
                self.pmsg_out.Clear()
                msgnum = self.pmsg_out.__getattribute__("BE_MSG")
                self.pmsg_out.type.append(msgnum)
                submsg = getattr(self.pmsg_out, "be_msg").add()
                submsg.status = m.status
                submsg.value = m.value
                raw_str = self.pmsg_out.SerializeToString()
                self.tprint(DEBUG, f"to FE {len(raw_str)}B {self.pmsg_out}")
                self.pub_fac.send(raw_str)
                self.msgout += 1

            # inbound from FE
            if self.rtr_fac in socks:
                sys.stdout.write("^")
                pkt = self.rtr_fac.recv_multipart()
                self.msgin += 1
                if len(pkt) == 3:  # sync
                    sync = SYNC
                    ident, empty, raw_msg = pkt
                    raw_out = self.fe2be_process_protobuf(raw_msg, sync)
                    self.rtr_fac.send_multipart([ident, b"", raw_out])
                    self.msgout += 1
                elif len(pkt) == 2:  # async
                    sync = ASYNC
                    ident, raw_msg = pkt
                    self.fe2be_process_protobuf(raw_msg, sync)

            if not socks:
                sys.stdout.write(".")
            sys.stdout.flush()
            self.tick += 1

    def be2fe_process_protobuf(self, raw_msg, sync):
        """
        Processes messages sent by the backend.

        :returns: Nothing (modifies class's protobufs).
        """
        self.pmsg_in.ParseFromString(raw_msg)
        self.tprint(DEBUG, "FROM BE\n%s" % (self.pmsg_in,))
        for t in self.pmsg_in.type:
            msgtype = facilities_pb2.FacMessage.Type.Name(t)  # msgtype = BE_MSG
            msglist = getattr(self.pmsg_in, msgtype.lower())
            m = msglist.pop(0)  # the particular message
            if msgtype == "BE_MSG":
                self.fmsg_out.Clear()
                msgnum = self.fmsg_out.__getattribute__("BE_MSG")
                self.fmsg_out.type.append(msgnum)
                submsg = getattr(self.fmsg_out, "be_msg").add()
                submsg.status = m.status
                submsg.value = m.value
                raw_str = self.fmsg_out.SerializeToString()
                self.tprint(DEBUG, f"to FE {self.didentity} {len(raw_str)}B {self.fmsg_out}")
                self.rtr_fac.send_multipart([self.didentity.encode(), raw_str])
                self.msgout += 1
        return

    # Split off between pass-through to BE or fac-provided IO.
    def fe2be_process_protobuf(self, raw_msg, sync):
        """
        Processes messages sent by the app. Some messages are processed locally (IO).
        Others are forwarded to the backend.

        :returns: Raw serialized protobuf.
        """
        self.fmsg_in.ParseFromString(raw_msg)
        self.tprint(DEBUG, "FROM FE\n%s" % (self.fmsg_in,))
        self.fmsg_out.Clear()
        for t in self.fmsg_in.type:
            msgtype = facilities_pb2.FacMessage.Type.Name(t)  # msgtype = BE_MSG
            msglist = getattr(self.fmsg_in, msgtype.lower())
            m = msglist.pop(0)  # the particular message
            if msgtype == "BE_MSG":  # Pass-through to BE
                self.fe2be_process_msg(m, sync)
            elif msgtype == "GET":  # IO
                self.process_get(m)
            elif msgtype == "SET":  # IO
                self.process_set(m)
        self.tprint(DEBUG, f"sending {self.fmsg_out}")
        raw_out = self.fmsg_out.SerializeToString()
        return raw_out

    def fe2be_process_msg(self, m, sync):
        """
        Processes messages sent by the app. All these are destined to the backend.

        :returns: Nothing (modifies class's protobufs).
        """
        self.pmsg_out.Clear()
        msgnum = self.pmsg_out.__getattribute__("BE_MSG")
        self.pmsg_out.type.append(msgnum)
        submsg = getattr(self.pmsg_out, "be_msg").add()
        submsg.status = m.status
        submsg.value = m.value
        raw_str = self.pmsg_out.SerializeToString()
        self.tprint(DEBUG, "sending to BE %dB %s" % (len(raw_str), self.pmsg_out))
        if sync == SYNC:
            self.req_pxy.send(raw_str)
            self.msgout += 1
            self.pmsg_in.ParseFromString(self.req_pxy.recv())
            self.msgin += 1
            self.tprint(DEBUG, "recv from BE %dB %s" % (len(raw_str), self.pmsg_in))
            index = 0
            submsg_list = getattr(self.pmsg_in, "be_msg_ret")
            status = submsg_list[index].status
            val = submsg_list[index].value
            self.fmsg_out.Clear()
            msgnum = self.fmsg_out.__getattribute__("BE_MSG_RET")
            self.fmsg_out.type.append(msgnum)
            submsg = getattr(self.fmsg_out, "be_msg_ret").add()
            submsg.status = status
            submsg.value = val
            # raw_str = self.pmsg_out.SerializeToString()
            # return raw_str
            return
        elif sync == ASYNC:
            self.dealer_pxy.send(raw_str)
            self.msgout += 1
            return

    def process_get(self, m):
        """
        The app is asking for previously stored data.

        :param m: The protobuf message sent by the app.
        :returns: Nothing (modifies class's protobufs).
        """
        key = str(m.key)
        msgtype = "GET_RET"
        msgnum = self.fmsg_out.__getattribute__(msgtype)
        self.fmsg_out.type.append(msgnum)
        submsg = getattr(self.fmsg_out, msgtype.lower()).add()
        submsg.key = key
        if not self.get_name_ok(key):
            submsg.value = ""
            submsg.result = VMI_FAILURE
            return
        try:
            io_id = self.storage[key]
            with open(FAC_IO_DIR + str(io_id), "rb") as fh:
                submsg.value = fh.read().decode()
            submsg.result = VMI_SUCCESS
        except IndexError:
            submsg.value = ""
            submsg.result = VMI_FAILURE
        return

    def process_set(self, m):
        """
        The app has requested we store this data.

        :param m: The protobuf message sent by the app.
        :returns: Nothing (modifies class's protobufs).
        """
        key = str(m.key)
        value = str(m.value)
        msgtype = "SET_RET"
        msgnum = self.fmsg_out.__getattribute__(msgtype)
        self.fmsg_out.type.append(msgnum)
        submsg = getattr(self.fmsg_out, msgtype.lower()).add()
        submsg.key = key
        if not self.get_name_ok(key):
            self.tprint(ERROR, "key not allowed")
            submsg.value = ""
            submsg.result = VMI_FAILURE
            return
        self.tprint(DEBUG,
                    f"cur={self.disk_usage}, req={len(value.encode())}, max={self.max_disk_usage}")
        if self.disk_usage + len(value.encode()) > self.max_disk_usage:
            self.tprint(ERROR, "disk usage exceeded")
            submsg.result = VMI_FAILURE
            return
        try:
            # io_id = str(uuid.uuid4())
            self.tprint(DEBUG, f"checking for existing key {key}")
            # overwrite a key
            try:
                existing_id = self.storage[key]
                self.tprint(DEBUG, "found!")
            except KeyError:
                self.tprint(DEBUG, "not found")
            else:
                s = os.path.getsize(FAC_IO_DIR + str(existing_id))
                self.tprint(DEBUG, f"deleting existing file of size={s}")
                self.disk_usage -= s
                os.remove(FAC_IO_DIR + str(existing_id))
                del self.storage[key]
            io_id = self.io_id
            self.io_id += 1
            self.tprint(DEBUG, f"will assign key {io_id}")
            self.tprint(DEBUG, f"writing file of size {len(value.encode())}")
            self.disk_usage += len(value)
            self.tprint(DEBUG, f"disk usage at {self.disk_usage}")
            self.storage[key] = io_id
            with open(FAC_IO_DIR + str(io_id), "wb") as fh:
                fh.write(value.encode())
            submsg.result = VMI_SUCCESS
        except:
            submsg.result = VMI_FAILURE
        return

    def get_name_ok(self, n):
        """
        :param n: The key (must be alphanumeric and <=16 characters).
        :returns: True if in accordance with naming convention, False if otherwise.
        """
        return 0 < len(n) < 17 and re.match(r"^[\w-]+$", n)

    def check_encoding(self, val):
        """
        Fun. Now with more Unicode.

        :param val: Is this variable a string?
        :returns: Nothing.
        :raises: TypeError if val is not a string.
        """
        if not isinstance(val, str):
            raise TypeError("Error: value is not of type str")

    def shutdown(self):
        """
        Exit cleanly.

        :returns: Nothing.
        """
        if self.already_shutdown:
            self.tprint(ERROR, "already shutdown?")
            return
        self.already_shutdown = True
        for key in self.storage:
            io_id = self.storage[key]
            self.tprint(INFO, "removing io: %s -> %s" % (key, io_id))
            os.remove(FAC_IO_DIR + str(io_id))
        self.auth.stop()
        self.poller.unregister(self.dealer_pxy)
        self.poller.unregister(self.sub_pxy)
        self.dealer_pxy.close()
        self.sub_pxy.close()
        self.poller.unregister(self.rtr_fac)
        self.rtr_fac.close()
        self.pub_fac.close()
        self.context.destroy()
        super(Facilities, self).shutdown()
def main(): """ Runs SEND either in transmitter or receiver mode """ parser = argparse.ArgumentParser() parser.add_argument( "-t", "--transmit", action="store_true", help="Flag indicating that user will be transmitting files" ) parser.add_argument( "-r", "--receive", action="store_true", help="Flag indicating that user will be receiving files" ) parser.add_argument( "--location", help="Location of files to send/receive. Can be a specific file if tx." ) parser.add_argument( "--ip", help="IP Address to form connection with" ) parser.add_argument( "--port", nargs='?', const=6000, default=6000, type=int, help="Port to form connection with (only needed if using non-default)" ) parser.add_argument( "--public_key", nargs='?', help="Public Key of transmitter in plain-text (only needed if receiver)" ) args=parser.parse_args() # Security Authentication Thread _generate_security_keys() authenticator = ThreadAuthenticator(manager.ctx) authenticator.start() whitelist = [ "127.0.0.1", args.ip ] authenticator.allow(*whitelist) authenticator.configure_curve(domain="*", location=PUBKEYS) try: if args.transmit: thread = manager.publish_folder( args.port, args.location ) elif args.receive: thread = manager.subscribe_folder( args.ip, args.port, args.location, args.public_key ) else: raise ValueError(f"User did not specify transmit/receive") except (OSError, ValueError): raise except KeyboardInterrupt: pass finally: # Keep things rolling until the transfer is done or the thread dies from # timing out while thread.isAlive(): pass thread.join() # Clean up and close everything out authenticator.stop() # Use destroy versus term: https://github.com/zeromq/pyzmq/issues/991 manager.ctx.destroy()
class MultiNodeAgent(BEMOSSAgent): def __init__(self, *args, **kwargs): super(MultiNodeAgent, self).__init__(*args, **kwargs) self.multinode_status = dict() self.getMultinodeData() self.agent_id = 'multinodeagent' self.is_parent = False self.last_sync_with_parent = datetime(1991, 1, 1) #equivalent to -ve infinitive self.parent_node = None self.recently_online_node_list = [] # initialize to lists to empty self.recently_offline_node_list = [ ] # they will be filled as nodes are discovered to be online/offline self.setup() self.runPeriodically(self.send_heartbeat, 20) self.runPeriodically(self.check_health, 60, start_immediately=False) self.runPeriodically(self.sync_all_with_parent, 600) self.subscribe('relay_message', self.relayDirectMessage) self.subscribe('update_multinode_data', self.updateMultinodeData) self.runContinuously(self.pollClients) self.run() def getMultinodeData(self): self.multinode_data = db_helper.get_multinode_data() self.nodelist_dict = { node['name']: node for node in self.multinode_data['known_nodes'] } self.node_name_list = [ node['name'] for node in self.multinode_data['known_nodes'] ] self.address_list = [ node['address'] for node in self.multinode_data['known_nodes'] ] self.server_key_list = [ node['server_key'] for node in self.multinode_data['known_nodes'] ] self.node_name = self.multinode_data['this_node'] for index, node in enumerate(self.multinode_data['known_nodes']): if node['name'] == self.node_name: self.node_index = index break else: raise ValueError( '"this_node:" entry on the multinode_data json file is invalid' ) for node_name in self.node_name_list: #initialize all nodes data if node_name not in self.multinode_status: #initialize new nodes. There could be already the node if this getMultiNode # data is being called later self.multinode_status[node_name] = dict() self.multinode_status[node_name][ 'health'] = -10 #initialized; never online/offline self.multinode_status[node_name]['last_sync_time'] = datetime( 1991, 1, 1) self.multinode_status[node_name]['last_online_time'] = None self.multinode_status[node_name]['last_offline_time'] = None self.multinode_status[node_name]['last_scanned_time'] = None def setup(self): print "Setup" base_dir = settings.PROJECT_DIR + "/" public_keys_dir = os.path.abspath(os.path.join(base_dir, 'public_keys')) secret_keys_dir = os.path.abspath( os.path.join(base_dir, 'private_keys')) self.secret_keys_dir = secret_keys_dir self.public_keys_dir = public_keys_dir if not (os.path.exists(public_keys_dir) and os.path.exists(secret_keys_dir)): logging.critical( "Certificates are missing - run generate_certificates.py script first" ) sys.exit(1) ctx = zmq.Context.instance() self.ctx = ctx # Start an authenticator for this context. 
self.auth = ThreadAuthenticator(ctx) self.auth.start() self.configure_authenticator() server = ctx.socket(zmq.PUB) server_secret_key_filename = self.multinode_data['known_nodes'][ self.node_index]['server_secret_key'] server_secret_file = os.path.join(secret_keys_dir, server_secret_key_filename) server_public, server_secret = zmq.auth.load_certificate( server_secret_file) server.curve_secretkey = server_secret server.curve_publickey = server_public server.curve_server = True # must come before bind server.bind( self.multinode_data['known_nodes'][self.node_index]['address']) self.server = server self.configureClient() def configure_authenticator(self): self.auth.allow() # Tell authenticator to use the certificate in a directory self.auth.configure_curve(domain='*', location=self.public_keys_dir) def disperseMessage(self, sender, topic, message): for node_name in self.node_name_list: if node_name == self.node_name: continue self.server.send( jsonify(sender, node_name + '/republish/' + topic, message)) def republishToParent(self, sender, topic, message): if self.is_parent: return #if I am parent, the message is already published for node_name in self.node_name_list: if self.multinode_status[node_name][ 'health'] == 2: #health = 2 is the parent node self.server.send( jsonify(sender, node_name + '/republish/' + topic, message)) def sync_node_with_parent(self, node_name): if self.is_parent: print "Syncing " + node_name self.last_sync_with_parent = datetime.now() sync_date_string = self.last_sync_with_parent.strftime( '%B %d, %Y, %H:%M:%S') # os.system('pg_dump bemossdb -f ' + self.self_database_dump_path) # with open(self.self_database_dump_path, 'r') as f: # file_content = f.read() # msg = {'database_dump': base64.b64encode(file_content)} self.server.send( jsonify( self.node_name, node_name + '/sync-with-parent/' + sync_date_string + '/' + self.node_name, "")) def sync_all_with_parent(self, dbcon): if self.is_parent: self.last_sync_with_parent = datetime.now() sync_date_string = self.last_sync_with_parent.strftime( '%B %d, %Y, %H:%M:%S') print "Syncing all nodes" for node_name in self.node_name_list: if node_name == self.node_name: continue # os.system('pg_dump bemossdb -f ' + self.self_database_dump_path) # with open(self.self_database_dump_path, 'r') as f: # file_content = f.read() # msg = {'database_dump': base64.b64encode(file_content)} self.server.send( jsonify( self.node_name, node_name + '/sync-with-parent/' + sync_date_string + '/' + self.node_name, "")) def send_heartbeat(self, dbcon): #self.vip.pubsub.publish('pubsub', 'listener', None, {'message': 'Hello Listener'}) #print 'publishing' print "Sending heartbeat" last_sync_string = self.last_sync_with_parent.strftime( '%B %d, %Y, %H:%M:%S') self.server.send( jsonify( self.node_name, 'heartbeat/' + self.node_name + '/' + str(self.is_parent) + '/' + last_sync_string, "")) def extract_ip(self, addr): return re.search(r'([0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3})', addr).groups()[0] def getNodeId(self, node_name): for index, node in enumerate(self.multinode_data['known_nodes']): if node['name'] == node_name: node_index = index break else: raise ValueError('the node name: ' + node_name + ' is not found in multinode data') return node_index def getNodeName(self, node_id): return self.multinode_data['known_nodes'][node_id]['name'] def handle_offline_nodes(self, dbcon, node_name_list): if self.is_parent: # start all the agents belonging to that node on this node command_group = [] for node_name in node_name_list: node_id = 
self.getNodeId(node_name) #put the offline event into cassandra events log table, and also create notification self.EventRegister(dbcon, 'node-offline', reason='communication-error', source=node_name) # get a list of agents that were supposedly running in that offline node dbcon.execute( "SELECT agent_id FROM " + node_devices_table + " WHERE assigned_node_id=%s", (node_id, )) if dbcon.rowcount: agent_ids = dbcon.fetchall() for agent_id in agent_ids: message = dict() message[STATUS_CHANGE.AGENT_ID] = agent_id[0] message[STATUS_CHANGE.NODE] = str(self.node_index) message[STATUS_CHANGE.AGENT_STATUS] = 'start' message[ STATUS_CHANGE. NODE_ASSIGNMENT_TYPE] = ZONE_ASSIGNMENT_TYPES.TEMPORARY command_group += [message] dbcon.execute( "UPDATE " + node_devices_table + " SET current_node_id=(%s), date_move=(%s)" " WHERE agent_id=(%s)", (self.node_index, datetime.now( pytz.UTC), agent_id[0])) dbcon.commit() print "moving agents from offline node to parent: " + str( node_name_list) print command_group if command_group: self.bemoss_publish(target='networkagent', topic='status_change', message=command_group) def handle_online_nodes(self, dbcon, node_name_list): if self.is_parent: # start all the agents belonging to that nodes back on them command_group = [] for node_name in node_name_list: node_id = self.getNodeId(node_name) if self.node_index == node_id: continue #don't handle self-online self.EventRegister(dbcon, 'node-online', reason='communication-restored', source=node_name) #get a list of agents that were supposed to be running in that online node dbcon.execute( "SELECT agent_id FROM " + node_devices_table + " WHERE assigned_node_id=%s", (node_id, )) if dbcon.rowcount: agent_ids = dbcon.fetchall() for agent_id in agent_ids: message = dict() message[STATUS_CHANGE.AGENT_ID] = agent_id[0] message[ STATUS_CHANGE. NODE_ASSIGNMENT_TYPE] = ZONE_ASSIGNMENT_TYPES.PERMANENT message[STATUS_CHANGE.NODE] = str(self.node_index) message[STATUS_CHANGE. AGENT_STATUS] = 'stop' #stop in this node command_group += [message] message = dict(message) #create another copy message[STATUS_CHANGE.NODE] = str(node_id) message[ STATUS_CHANGE. AGENT_STATUS] = 'start' #start in the target node command_group += [message] #immediately update the multnode device assignment table dbcon.execute( "UPDATE " + node_devices_table + " SET current_node_id=(%s), date_move=(%s)" " WHERE agent_id=(%s)", (node_id, datetime.now(pytz.UTC), agent_id[0])) dbcon.commit() print "Moving agents back to the online node: " + str( node_name_list) print command_group if command_group: self.bemoss_publish(target='networkagent', topic='status_change', message=command_group) def updateParent(self, dbcon, parent_node_name): parent_ip = self.extract_ip( self.nodelist_dict[parent_node_name]['address']) write_new = False if not os.path.isfile(settings.MULTINODE_PARENT_IP_FILE ): # but parent file doesn't exists write_new = True else: with open(settings.MULTINODE_PARENT_IP_FILE, 'r') as f: read_ip = f.read() if read_ip != parent_ip: write_new = True if write_new: with open(settings.MULTINODE_PARENT_IP_FILE, 'w') as f: f.write(parent_ip) if dbcon: dbcon.close() #close old connection dbcon = db_helper.db_connection( ) #start new connection using new parent_ip self.updateMultinodeData(sender=self.name, topic='update_parent', message="") def check_health(self, dbcon): for node_name, node in self.multinode_status.items(): if node['health'] > 0: #initialize all online nodes to 0. 
If they are really online, they should change it # back to 1 or 2 (parent) within 30 seconds throught the heartbeat. node['health'] = 0 time.sleep(30) parent_node_name = None #initialize parent node online_node_exists = False for node_name, node in self.multinode_status.items(): node['last_scanned_time'] = datetime.now() if node['health'] == 0: node['health'] = -1 node['last_offline_time'] = datetime.now() self.recently_offline_node_list += [node_name] elif node['health'] == -1: #offline since long pass elif node[ 'health'] == -10: #The node was initialized to -10, and never came online. Treat it as recently going # offline for this iteration so that the agents that were supposed to be running there can be migrated node['health'] = -1 self.recently_offline_node_list += [node_name] elif node['health'] == 2: #there is some parent node present parent_node_name = node_name if node['health'] > 0: online_node_exists = True #At-least one node (itself) should be online, if not some problem assert online_node_exists, "At least one node (current node) must be online" if not parent_node_name: #parent node doesn't exist #find a suitable node to elect a parent. The node with latest update from previous parent wins. If there is #tie, then the node coming earlier in the node-list on multinode data wins online_node_last_sync = dict( ) #only the online nodes, and their last-sync-times for node_name, node in self.multinode_status.items( ): #copy only the online nodes if node['health'] > 0: online_node_last_sync[node_name] = node['last_sync_time'] latest_node = max(online_node_last_sync, key=online_node_last_sync.get) latest_sync_date = online_node_last_sync[latest_node] for node_name in self.node_name_list: if self.multinode_status[node_name][ 'health'] <= 0: #dead nodes can't be parents continue if self.multinode_status[node_name][ 'last_sync_time'] == latest_sync_date: # this is the first node with the latest update from parent #elligible parent found self.updateParent(dbcon, node_name) if node_name == self.node_name: # I am the node, so I get to become the parent self.is_parent = True print "I am the boss now, " + self.node_name break else: #I-am-not-the-first-node with latest update; somebody else is self.is_parent = False break else: #parent node exist self.updateParent(dbcon, parent_node_name) for node in self.multinode_data['known_nodes']: print node['name'] + ': ' + str( self.multinode_status[node['name']]['health']) if self.is_parent: #if this is a parent node, update the node_info table if dbcon is None: #if no database connection exists make connection dbcon = db_helper.db_connection() tbl_node_info = settings.DATABASES['default']['TABLE_node_info'] dbcon.execute('select node_id from ' + tbl_node_info) to_be_deleted_node_ids = dbcon.fetchall() for index, node in enumerate(self.multinode_data['known_nodes']): if (index, ) in to_be_deleted_node_ids: to_be_deleted_node_ids.remove( (index, )) #don't remove this current node result = dbcon.execute( 'select * from ' + tbl_node_info + ' where node_id=%s', (index, )) node_type = 'parent' if self.multinode_status[ node['name']]['health'] == 2 else "child" node_status = "ONLINE" if self.multinode_status[ node['name']]['health'] > 0 else "OFFLINE" ip_address = self.extract_ip(node['address']) last_scanned_time = self.multinode_status[ node['name']]['last_online_time'] last_offline_time = self.multinode_status[ node['name']]['last_offline_time'] last_sync_time = self.multinode_status[ node['name']]['last_sync_time'] var_list = 
"(node_id,node_name,node_type,node_status,ip_address,last_scanned_time,last_offline_time,last_sync_time)" value_placeholder_list = "(%s,%s,%s,%s,%s,%s,%s,%s)" actual_values_list = (index, node['name'], node_type, node_status, ip_address, last_scanned_time, last_offline_time, last_sync_time) if dbcon.rowcount == 0: dbcon.execute( "insert into " + tbl_node_info + " " + var_list + " VALUES" + value_placeholder_list, actual_values_list) else: dbcon.execute( "update " + tbl_node_info + " SET " + var_list + " = " + value_placeholder_list + " where node_id = %s", actual_values_list + (index, )) dbcon.commit() for id in to_be_deleted_node_ids: dbcon.execute( 'delete from accounts_userprofile_nodes where nodeinfo_id=%s', id) #delete entries in user-profile for the old node dbcon.commit() dbcon.execute('delete from ' + tbl_node_info + ' where node_id=%s', id) #delete the old nodes dbcon.commit() if self.recently_online_node_list: #Online nodes should be handled first because, the same node can first be #on both recently_online_node_list and recently_offline_node_list, when it goes offline shortly after #coming online self.handle_online_nodes(dbcon, self.recently_online_node_list) self.recently_online_node_list = [] # reset after handling if self.recently_offline_node_list: self.handle_offline_nodes(dbcon, self.recently_offline_node_list) self.recently_offline_node_list = [] # reset after handling def connect_client(self, node): server_public_file = os.path.join(self.public_keys_dir, node['server_key']) server_public, _ = zmq.auth.load_certificate(server_public_file) # The client must know the server's public key to make a CURVE connection. self.client.curve_serverkey = server_public self.client.setsockopt(zmq.SUBSCRIBE, 'heartbeat/') self.client.setsockopt(zmq.SUBSCRIBE, self.node_name) self.client.connect(node['address']) def disconnect_client(self, node): self.client.disconnect(node['address']) def configureClient(self): print "Starting to receive Heart-beat" client = self.ctx.socket(zmq.SUB) # We need two certificates, one for the client and one for # the server. The client must know the server's public key # to make a CURVE connection. 
client_secret_key_filename = self.multinode_data['known_nodes'][ self.node_index]['client_secret_key'] client_secret_file = os.path.join(self.secret_keys_dir, client_secret_key_filename) client_public, client_secret = zmq.auth.load_certificate( client_secret_file) client.curve_secretkey = client_secret client.curve_publickey = client_public self.client = client for node in self.multinode_data['known_nodes']: self.connect_client(node) def pollClients(self, dbcon): if self.client.poll(1000): sender, topic, msg = dejsonify(self.client.recv()) topic_list = topic.split('/') if topic_list[0] == 'heartbeat': node_name = sender is_parent = topic_list[2] last_sync_with_parent = topic_list[3] if self.multinode_status[node_name][ 'health'] < 0: #the node health was <0 , means offline print node_name + " is back online" self.recently_online_node_list += [node_name] self.sync_node_with_parent(node_name) if is_parent.lower() in ['false', '0']: self.multinode_status[node_name]['health'] = 1 elif is_parent.lower() in ['true', '1']: self.multinode_status[node_name]['health'] = 2 self.parent_node = node_name else: raise ValueError( 'Invalid is_parent string in heart-beat message') self.multinode_status[node_name][ 'last_online_time'] = datetime.now() self.multinode_status[node_name][ 'last_sync_time'] = datetime.strptime( last_sync_with_parent, '%B %d, %Y, %H:%M:%S') if topic_list[0] == self.node_name: if topic_list[1] == 'sync-with-parent': pass # print topic # self.last_sync_with_parent = datetime.strptime(topic_list[2], '%B %d, %Y, %H:%M:%S') # content = base64.b64decode(msg['database_dump']) # newpath = 'bemossdb.sql' # with open(newpath, 'w') as f: # f.write(content) # try: # os.system( # 'psql -c "SELECT pg_terminate_backend(pg_stat_activity.pid) FROM pg_stat_activity WHERE pid <> pg_backend_pid();"') # os.system( # 'dropdb bemossdb') # This step requires all connections to be closed # os.system('createdb bemossdb -O admin') # dump_result = subprocess.check_output('psql bemossdb < ' + newpath, shell=True) # except Exception as er: # print "Couldn't sync database with parent because of error: " # print er # # parent_node_name = topic_list[3] # self.updateParent(parent_node_name) if topic_list[1] == 'republish': target = msg['target'] actual_message = msg['actual_message'] actual_topic = msg['actual_topic'] self.bemoss_publish(target=target, topic=actual_topic + '/republished', message=actual_message, sender=sender) print self.node_name + ": " + topic, str(msg) else: time.sleep(2) def cleanup(self): # stop auth thread self.auth.stop() def updateMultinodeData(self, dbcon, sender, topic, message): print "Updating Multinode data" topic_list = topic.split('/') self.configure_authenticator() #to/multinodeagent/from/<doesn't matter>/update_multinode_data if topic_list[4] == 'update_multinode_data': old_multinode_data = self.multinode_data self.getMultinodeData() for node in self.multinode_data['known_nodes']: if node not in old_multinode_data['known_nodes']: print "New node has been added to the cluster: " + node[ 'name'] print "We will connect to it" self.connect_client(node) for node in old_multinode_data['known_nodes']: if node not in self.multinode_data['known_nodes']: print "Node has been removed from the cluster: " + node[ 'name'] print "We will disconnect from it" self.disconnect_client(node) # TODO: remove it from the node_info table print "yay! 
got it" def relayDirectMessage(self, dbcon, sender, topic, message): print topic #to/<some_agent_or_ui>/topic/from/<some_agent_or_ui> from_entity = sender target = message['target'] actual_message = message['actual_message'] actual_topic = message['actual_topic'] for to_entity in target: if to_entity in settings.NO_FORWARD_AGENTS: return #no forwarding should be done for these agents elif to_entity in settings.PARENT_NODE_SYSTEM_AGENTS: if not self.is_parent: self.republishToParent(sender, topic, message) elif to_entity == "ALL": self.disperseMessage(sender, topic=topic, message=message) else: dbcon.execute( "SELECT current_node_id FROM " + node_devices_table + " WHERE agent_id=%s", (to_entity, )) if dbcon.rowcount: node_id = dbcon.fetchone()[0] if node_id != self.node_index: self.server.send( jsonify( sender, self.getNodeName(node_id) + '/republish/' + topic, message)) else: self.disperseMessage( sender, topic, message ) #republish to all nodes if we don't know where to send
class SecurityManager(object): """ Manages the security related settings and actions. Attributes: private_cert_dir (str): The path towards the directory that stores private certificates. public_cert_dir (str): The path towards the directory that stores public certificates. temp_cert_dir (str): The path towards the temporary directory (used for generating new certificates). If not provided, it defaults to system's temporary directory. no_encryption (bool): Enable or disable encryption at peer level. Peers that use encryption can only connect to other peers that use encryption and vice-versa. public_file (str): The path towards the public certificate of this peer. private_file (str): The path towards the private certificate of this peer. auth_thread (ThreadAuthenticator): A separate thread used by zmq to authenticate our peers. """ def __init__(self, private_cert_dir, public_cert_dir, temp_cert_dir=None, no_encryption=False, *args, **kwargs): """ Constructor. Arguments: private_cert_dir (str): The path towards the directory that stores private certificates. public_cert_dir (str): The path towards the directory that stores public certificates. temp_cert_dir (str): The path towards the temporary directory (used for generating new certificates). Defaults to system's temporary directory. no_encryption (bool): Enable or disable encryption at peer level. Peers that use encryption can only connect to other peers that use encryption and vice-versa. """ super(SecurityManager, self).__init__(*args, **kwargs) self.private_cert_dir = private_cert_dir self.public_cert_dir = public_cert_dir if temp_cert_dir is None: self.temp_cert_dir = tempfile.gettempdir() else: self.temp_cert_dir = temp_cert_dir # Paths for certificates used by this instance. self.public_file = None self.private_file = None # Security settings. self.no_encryption = no_encryption # Authentication thread. self.auth_thread = None def start_auth(self, context): """ Starts the authentication thread if encryption is enabled. Arguments: context (zmq.Context): The context where the authentication thread will belong. """ if not self.no_encryption: logger.debug("Authenticator thread is being started") self.auth_thread = ThreadAuthenticator( context=context, encoding='utf-8', log=logging.getLogger('zmq_auth')) self.auth_thread.start() self.auth_thread.thread.name = 'zmq_auth' self.auth_thread.configure_curve(domain='*', location=self.public_cert_dir) def terminate_auth(self): """ Ends the authentication thread if encryption is enabled. This method should be written defensively, as the environment might not be fully set (an exception in :meth:`p2p0mq.app.theapp.LocalPeer.create` does not prevent this method from being executed). """ if self.auth_thread is not None: logger.debug("Authenticator thread is being stopped") self.auth_thread.stop() self.auth_thread = None def prepare_cert_store(self, uuid): """ Prepares the directory structure before it can be used by our authentication system. Arguments: uuid: The unique identification of local peer. """ if not os.path.isdir(self.private_cert_dir): os.makedirs(self.private_cert_dir) if not os.path.isdir(self.public_cert_dir): os.makedirs(self.public_cert_dir) if not os.path.isdir(self.temp_cert_dir): os.makedirs(self.temp_cert_dir) self.public_file, self.private_file = \ self.cert_pair_check_gen(uuid) def cert_pair_check_gen(self, uuid): """ Checks if the certificates exist. Generates them if they don't. Arguments: uuid: The unique identification of the peer. Usually, this is the local peer. 
""" cert_pub = self.cert_file_by_uuid(uuid, public=True) cert_prv = self.cert_file_by_uuid(uuid, public=False) pub_exists = os.path.isfile(cert_pub) prv_exists = os.path.isfile(cert_prv) if pub_exists and prv_exists: # Both files exist. Yey. pass elif pub_exists and not prv_exists: # The public certificate exists but is unusable without # the private one. raise RuntimeError( "The public certificate has been found at %s, " "which indicates that a key has been generated, " "but the private certificate is not at %s", cert_pub, cert_prv) elif not pub_exists and prv_exists: # The private certificate exists but the public one doesn't. # We can extract the key from the private one. with open(cert_prv, 'r') as fin: data = re.sub(r'.*private-key = "(.+)"', "", fin.read(), re.MULTILINE) with open(cert_pub, 'w') as fout: fout.write(data) else: # Neither exists. public_file, secret_file = \ zmq.auth.create_certificates( self.temp_cert_dir, '%r' % time()) shutil.move(public_file, cert_pub) shutil.move(secret_file, cert_prv) return cert_pub, cert_prv def cert_file_by_uuid(self, uuid, public=True): """ Computes the path of a certificate inside the certificate store based on the name of the peer. Arguments: uuid: The unique identification of the peer. Usually, this is the local peer. public (bool): If True it retrieves the path of the public certificate, if False it retrieves the path of the private certificate. """ if isinstance(uuid, bytes): uuid = uuid.decode('utf-8') pb_vs_pv = 'key' if public else 'key_secret' return os.path.join( self.public_cert_dir if public else self.private_cert_dir, '%s.%s' % (uuid, pb_vs_pv)) def cert_key_by_uuid(self, uuid, public=True): """ Reads the key from corresponding certificate file. Arguments: uuid: The unique identification of the peer. Usually, this is the local peer. public (bool): If True it retrieves the public key, if False it retrieves the private key. """ file = self.cert_file_by_uuid(uuid=uuid, public=public) logger.debug("%s certificate for uuid %s is loaded from %s", 'Public' if public else 'Private', uuid, file) if not os.path.exists(file): return None public_key, secret_key = zmq.auth.load_certificate(file) return public_key if public else secret_key def exchange_certificates(self, other): """ Copies the certificates so that the two instances can authenticate themselves to each other. Arguments: other (SecurityManager): The security manager of the other peer. """ if self.no_encryption: logger.error("Encryption was disabled in %s", self) return if other.no_encryption: logger.error("Encryption was disabled in %s", other) return shutil.copy( self.public_file, os.path.join(other.public_cert_dir, os.path.basename(self.public_file))) shutil.copy( other.public_file, os.path.join(self.public_cert_dir, os.path.basename(other.public_file))) self.auth_thread.configure_curve(domain='*', location=self.public_cert_dir) other.auth_thread.configure_curve(domain='*', location=other.public_cert_dir)
class Command(LAVADaemonCommand): help = "LAVA log recorder" logger = None default_logfile = "/var/log/lava-server/lava-logs.log" def __init__(self, *args, **options): super(Command, self).__init__(*args, **options) self.logger = logging.getLogger("lava-logs") self.log_socket = None self.auth = None self.controler = None self.inotify_fd = None self.pipe_r = None self.poller = None self.cert_dir_path = None # List of logs self.jobs = {} # Keep test cases in memory self.test_cases = [] # Master status self.last_ping = 0 self.ping_interval = TIMEOUT def add_arguments(self, parser): super(Command, self).add_arguments(parser) net = parser.add_argument_group("network") net.add_argument('--socket', default='tcp://*:5555', help="Socket waiting for logs. Default: tcp://*:5555") net.add_argument('--master-socket', default='tcp://localhost:5556', help="Socket for master-slave communication. Default: tcp://localhost:5556") net.add_argument('--ipv6', default=False, action='store_true', help="Enable IPv6 on the listening sockets") net.add_argument('--encrypt', default=False, action='store_true', help="Encrypt messages") net.add_argument('--master-cert', default='/etc/lava-dispatcher/certificates.d/master.key_secret', help="Certificate for the master socket") net.add_argument('--slaves-certs', default='/etc/lava-dispatcher/certificates.d', help="Directory for slaves certificates") def handle(self, *args, **options): # Initialize logging. self.setup_logging("lava-logs", options["level"], options["log_file"], FORMAT) self.logger.info("[INIT] Dropping privileges") if not self.drop_privileges(options['user'], options['group']): self.logger.error("[INIT] Unable to drop privileges") return # Create the sockets context = zmq.Context() self.log_socket = context.socket(zmq.PULL) self.controler = context.socket(zmq.ROUTER) self.controler.setsockopt(zmq.IDENTITY, b"lava-logs") # Limit the number of messages in the queue self.controler.setsockopt(zmq.SNDHWM, 2) # From http://api.zeromq.org/4-2:zmq-setsockopt#toc5 # "Immediately readies that connection for data transfer with the master" self.controler.setsockopt(zmq.CONNECT_RID, b"master") if options['ipv6']: self.logger.info("[INIT] Enabling IPv6") self.log_socket.setsockopt(zmq.IPV6, 1) self.controler.setsockopt(zmq.IPV6, 1) if options['encrypt']: self.logger.info("[INIT] Starting encryption") try: self.auth = ThreadAuthenticator(context) self.auth.start() self.logger.debug("[INIT] Opening master certificate: %s", options['master_cert']) master_public, master_secret = zmq.auth.load_certificate(options['master_cert']) self.logger.debug("[INIT] Using slaves certificates from: %s", options['slaves_certs']) self.auth.configure_curve(domain='*', location=options['slaves_certs']) except IOError as err: self.logger.error("[INIT] %s", err) self.auth.stop() return self.log_socket.curve_publickey = master_public self.log_socket.curve_secretkey = master_secret self.log_socket.curve_server = True self.controler.curve_publickey = master_public self.controler.curve_secretkey = master_secret self.controler.curve_serverkey = master_public self.logger.debug("[INIT] Watching %s", options["slaves_certs"]) self.cert_dir_path = options["slaves_certs"] self.inotify_fd = watch_directory(options["slaves_certs"]) if self.inotify_fd is None: self.logger.error("[INIT] Unable to start inotify") self.log_socket.bind(options['socket']) self.controler.connect(options['master_socket']) # Poll on the sockets. This allow to have a # nice timeout along with polling. 
self.poller = zmq.Poller() self.poller.register(self.log_socket, zmq.POLLIN) self.poller.register(self.controler, zmq.POLLIN) if self.inotify_fd is not None: self.poller.register(os.fdopen(self.inotify_fd), zmq.POLLIN) # Translate signals into zmq messages (self.pipe_r, _) = self.setup_zmq_signal_handler() self.poller.register(self.pipe_r, zmq.POLLIN) self.logger.info("[INIT] listening for logs") # PING right now: the master is waiting for this message to start # scheduling. self.controler.send_multipart([b"master", b"PING"]) try: self.main_loop() except BaseException as exc: self.logger.error("[EXIT] Unknown exception raised, leaving!") self.logger.exception(exc) # Close the controler socket self.controler.close(linger=0) self.poller.unregister(self.controler) # Carefully close the logging socket as we don't want to lose messages self.logger.info("[EXIT] Disconnect logging socket and process messages") endpoint = u(self.log_socket.getsockopt(zmq.LAST_ENDPOINT)) self.logger.debug("[EXIT] unbinding from '%s'", endpoint) self.log_socket.unbind(endpoint) # Empty the queue try: while self.wait_for_messages(True): # Flush test cases cache for every iteration because we might # get killed soon. self.flush_test_cases() except BaseException as exc: self.logger.error("[EXIT] Unknown exception raised, leaving!") self.logger.exception(exc) finally: # Last flush self.flush_test_cases() self.logger.info("[EXIT] Closing the logging socket: the queue is empty") self.log_socket.close() if options['encrypt']: self.auth.stop() context.term() def flush_test_cases(self): if self.test_cases: self.logger.info("Saving %d test cases", len(self.test_cases)) TestCase.objects.bulk_create(self.test_cases) self.test_cases = [] def main_loop(self): last_gc = time.time() last_bulk_create = time.time() # Wait for messages # TODO: fix timeout computation while self.wait_for_messages(False): now = time.time() # Dump TestCase into the database if now - last_bulk_create > BULK_CREATE_TIMEOUT: last_bulk_create = now self.flush_test_cases() # Close old file handlers if now - last_gc > FD_TIMEOUT: last_gc = now # Iterate while removing keys is not compatible with iterator for job_id in list(self.jobs.keys()): # pylint: disable=consider-iterating-dictionary if now - self.jobs[job_id].last_usage > FD_TIMEOUT: self.logger.info("[%s] closing log file", job_id) self.jobs[job_id].close() del self.jobs[job_id] # Ping the master if now - self.last_ping > self.ping_interval: self.logger.debug("PING => master") self.last_ping = now self.controler.send_multipart([b"master", b"PING"]) def wait_for_messages(self, leaving): try: try: sockets = dict(self.poller.poll(TIMEOUT * 1000)) except zmq.error.ZMQError as exc: self.logger.error("[POLL] zmq error: %s", str(exc)) return True # Messages if sockets.get(self.log_socket) == zmq.POLLIN: self.logging_socket() return True # Signals elif sockets.get(self.pipe_r) == zmq.POLLIN: # remove the message from the queue os.read(self.pipe_r, 1) if not leaving: self.logger.info("[POLL] received a signal, leaving") return False else: self.logger.warning("[POLL] signal already handled, please wait for the process to exit") return True # Pong received elif sockets.get(self.controler) == zmq.POLLIN: self.controler_socket() return True # Inotify socket if sockets.get(self.inotify_fd) == zmq.POLLIN: os.read(self.inotify_fd, 4096) self.logger.debug("[AUTH] Reloading certificates from %s", self.cert_dir_path) self.auth.configure_curve(domain='*', location=self.cert_dir_path) # Nothing received else: return not 
leaving except (OperationalError, InterfaceError): self.logger.info("[RESET] database connection reset") connection.close() return True def logging_socket(self): msg = self.log_socket.recv_multipart() try: (job_id, message) = (u(m) for m in msg) # pylint: disable=unbalanced-tuple-unpacking except ValueError: # do not let a bad message stop the master. self.logger.error("[POLL] failed to parse log message, skipping: %s", msg) return try: scanned = yaml.load(message, Loader=yaml.CLoader) except yaml.YAMLError: self.logger.error("[%s] data are not valid YAML, dropping", job_id) return # Look for "results" level try: message_lvl = scanned["lvl"] message_msg = scanned["msg"] except TypeError: self.logger.error("[%s] not a dictionary, dropping", job_id) return except KeyError: self.logger.error( "[%s] invalid log line, missing \"lvl\" or \"msg\" keys: %s", job_id, message) return # Find the handler (if available) if job_id not in self.jobs: # Query the database for the job try: job = TestJob.objects.get(id=job_id) except TestJob.DoesNotExist: self.logger.error("[%s] unknown job id", job_id) return self.logger.info("[%s] receiving logs from a new job", job_id) # Create the sub directories (if needed) mkdir(job.output_dir) self.jobs[job_id] = JobHandler(job) if message_lvl == "results": try: job = TestJob.objects.get(pk=job_id) except TestJob.DoesNotExist: self.logger.error("[%s] unknown job id", job_id) return meta_filename = create_metadata_store(message_msg, job) new_test_case = map_scanned_results(results=message_msg, job=job, meta_filename=meta_filename) if new_test_case is None: self.logger.warning( "[%s] unable to map scanned results: %s", job_id, message) else: self.test_cases.append(new_test_case) # Look for lava.job result if message_msg.get("definition") == "lava" and message_msg.get("case") == "job": # Flush cached test cases self.flush_test_cases() if message_msg.get("result") == "pass": health = TestJob.HEALTH_COMPLETE health_msg = "Complete" else: health = TestJob.HEALTH_INCOMPLETE health_msg = "Incomplete" self.logger.info("[%s] job status: %s", job_id, health_msg) infrastructure_error = (message_msg.get("error_type") in ["Bug", "Configuration", "Infrastructure"]) if infrastructure_error: self.logger.info("[%s] Infrastructure error", job_id) # Update status. with transaction.atomic(): # TODO: find a way to lock actual_device job = TestJob.objects.select_for_update() \ .get(id=job_id) job.go_state_finished(health, infrastructure_error) job.save() # Mark the file handler as used self.jobs[job_id].last_usage = time.time() # n.b. logging here would produce a log entry for every message in every job. # The format is a list of dictionaries message = "- %s" % message # Write data self.jobs[job_id].write(message) def controler_socket(self): msg = self.controler.recv_multipart() try: master_id = u(msg[0]) action = u(msg[1]) ping_interval = int(msg[2]) if master_id != "master": self.logger.error("Invalid master id '%s'. Should be 'master'", master_id) return if action != "PONG": self.logger.error("Invalid answer '%s'. Should be 'PONG'", action) return except (IndexError, ValueError): self.logger.error("Invalid message '%s'", msg) return if ping_interval < TIMEOUT: self.logger.error("invalid ping interval (%d) too small", ping_interval) return self.logger.debug("master => PONG(%d)", ping_interval) self.ping_interval = ping_interval
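The --encrypt path above turns lava-logs into a CURVE server on its PULL socket. The matching sender side would look roughly like the sketch below; certificate paths and the sample payload are illustrative, not lava-dispatcher's actual slave code:

import zmq
import zmq.auth

ctx = zmq.Context()
push = ctx.socket(zmq.PUSH)

# The sender's own keypair; its public half must sit in the --slaves-certs
# directory for configure_curve() to accept the connection.
slave_public, slave_secret = zmq.auth.load_certificate('slave.key_secret')
master_public, _ = zmq.auth.load_certificate('master.key')

push.curve_publickey = slave_public
push.curve_secretkey = slave_secret
push.curve_serverkey = master_public  # lava-logs binds as the CURVE server

push.connect('tcp://localhost:5555')
# logging_socket() expects [job_id, yaml_message] where the message is a
# YAML mapping carrying "lvl" and "msg" keys.
push.send_multipart([b'42', b'{"lvl": "info", "msg": "hello"}'])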
class RpcClient: """""" def __init__(self): """Constructor""" # zmq port related self.__context: zmq.Context = zmq.Context() # Request socket (Request–reply pattern) self.__socket_req: zmq.Socket = self.__context.socket(zmq.REQ) # Subscribe socket (Publish–subscribe pattern) self.__socket_sub: zmq.Socket = self.__context.socket(zmq.SUB) # Worker thread relate, used to process data pushed from server self.__active: bool = False # RpcClient status self.__thread: threading.Thread = None # RpcClient thread self.__lock: threading.Lock = threading.Lock() # Authenticator used to ensure data security self.__authenticator: ThreadAuthenticator = None self._last_received_ping: datetime = datetime.utcnow() @lru_cache(100) def __getattr__(self, name: str): """ Realize remote call function """ # Perform remote call task def dorpc(*args, **kwargs): # Get timeout value from kwargs, default value is 30 seconds if "timeout" in kwargs: timeout = kwargs.pop("timeout") else: timeout = 30000 # Generate request req = [name, args, kwargs] # Send request and wait for response with self.__lock: self.__socket_req.send_pyobj(req) # Timeout reached without any data n = self.__socket_req.poll(timeout) if not n: msg = f"Timeout of {timeout}ms reached for {req}" raise RemoteException(msg) rep = self.__socket_req.recv_pyobj() # Return response if successed; Trigger exception if failed if rep[0]: return rep[1] else: raise RemoteException(rep[1]) return dorpc def start(self, req_address: str, sub_address: str, client_secretkey_path: str = "", server_publickey_path: str = "", username: str = "", password: str = "") -> None: """ Start RpcClient """ if self.__active: return # Start authenticator if client_secretkey_path and server_publickey_path: self.__authenticator = ThreadAuthenticator(self.__context) self.__authenticator.start() self.__authenticator.configure_curve( domain="*", location=zmq.auth.CURVE_ALLOW_ANY) publickey, secretkey = zmq.auth.load_certificate( client_secretkey_path) serverkey, _ = zmq.auth.load_certificate(server_publickey_path) self.__socket_sub.curve_secretkey = secretkey self.__socket_sub.curve_publickey = publickey self.__socket_sub.curve_serverkey = serverkey self.__socket_req.curve_secretkey = secretkey self.__socket_req.curve_publickey = publickey self.__socket_req.curve_serverkey = serverkey elif username and password: self.__authenticator = ThreadAuthenticator(self.__context) self.__authenticator.start() self.__authenticator.configure_plain( domain="*", passwords={username: password}) self.__socket_sub.plain_username = username.encode() self.__socket_sub.plain_password = password.encode() self.__socket_req.plain_username = username.encode() self.__socket_req.plain_password = password.encode() # Connect zmq port self.__socket_req.connect(req_address) self.__socket_sub.connect(sub_address) # Start RpcClient status self.__active = True # Start RpcClient thread self.__thread = threading.Thread(target=self.run) self.__thread.start() self._last_received_ping = datetime.utcnow() def stop(self) -> None: """ Stop RpcClient """ if not self.__active: return # Stop RpcClient status self.__active = False def join(self) -> None: # Wait for RpcClient thread to exit if self.__thread and self.__thread.is_alive(): self.__thread.join() self.__thread = None def run(self) -> None: """ Run RpcClient function """ pull_tolerance = int(KEEP_ALIVE_TOLERANCE.total_seconds() * 1000) while self.__active: if not self.__socket_sub.poll(pull_tolerance): self.on_disconnected() continue # Receive data from subscribe socket topic, 
            data = self.__socket_sub.recv_pyobj(flags=NOBLOCK)

            if topic == KEEP_ALIVE_TOPIC:
                self._last_received_ping = data
            else:
                # Process data by callable function
                self.callback(topic, data)

        # Close socket
        self.__socket_req.close()
        self.__socket_sub.close()

        if self.__authenticator:
            self.__authenticator.stop()

    def callback(self, topic: str, data: Any) -> None:
        """
        Callable function
        """
        raise NotImplementedError

    def subscribe_topic(self, topic: str) -> None:
        """
        Subscribe data
        """
        self.__socket_sub.setsockopt_string(zmq.SUBSCRIBE, topic)

    def on_disconnected(self):
        """
        Callback when heartbeat is lost.
        """
        print(
            "RpcServer has no response over {tolerance} seconds, "
            "please check your connection."
            .format(tolerance=KEEP_ALIVE_TOLERANCE.total_seconds())
        )
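How RpcClient is meant to be driven, as a brief sketch; addresses and key paths are placeholders:

class MyClient(RpcClient):
    def callback(self, topic, data):
        # Called for every non-heartbeat publication from the server.
        print("update on %s: %s" % (topic, data))

client = MyClient()
client.subscribe_topic("")  # empty prefix subscribes to everything
client.start(
    req_address="tcp://localhost:2014",
    sub_address="tcp://localhost:4102",
    client_secretkey_path="client.key_secret",
    server_publickey_path="server.key",
)
# Remote calls are proxied through __getattr__, e.g.:
#   result = client.add(1, 2, timeout=5000)
client.stop()
client.join()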
class StratusApp(StratusServerApp):

    def __init__(self, core: StratusCore, **kwargs):
        StratusServerApp.__init__(self, core, **kwargs)
        self.logger = StratusLogger.getLogger()
        self.active = True
        self.parms = self.getConfigParms('stratus')
        self.client_address = self.parms.get("client_address", "*")
        self.request_port = self.parms.get("request_port", 4556)
        self.response_port = self.parms.get("response_port", 4557)
        self.active_handlers = {}
        self.getCertDirs()

    def getCertDirs(self):
        # These directories are generated by the generate_certificates script
        self.cert_dir = self.parms.get("certificate_path",
                                       os.path.expanduser("~/.stratus/zmq"))
        self.logger.info(
            f"Loading certificates and keys from directory {self.cert_dir}")
        self.keys_dir = os.path.join(self.cert_dir, 'certificates')
        self.public_keys_dir = os.path.join(self.cert_dir, 'public_keys')
        self.secret_keys_dir = os.path.join(self.cert_dir, 'private_keys')
        if not (os.path.exists(self.keys_dir)
                and os.path.exists(self.public_keys_dir)
                and os.path.exists(self.secret_keys_dir)):
            from stratus.handlers.zeromq.security.generate_certificates import generate_certificates
            generate_certificates(self.cert_dir)

    def initSocket(self):
        try:
            server_secret_file = os.path.join(self.secret_keys_dir,
                                              "server.key_secret")
            server_public, server_secret = zmq.auth.load_certificate(
                server_secret_file)
            # TODO: this is commented to avoid key checking
            # self.request_socket.curve_secretkey = server_secret
            # self.request_socket.curve_publickey = server_public
            # self.request_socket.curve_server = True
            self.request_socket.bind("tcp://{}:{}".format(
                self.client_address, self.request_port))
            self.logger.info(
                "@@STRATUS-APP --> Bound authenticated request socket to client at {} on port: {}"
                .format(self.client_address, self.request_port))
        except Exception as err:
            self.logger.error(
                "@@STRATUS-APP: Error initializing request socket on {}, port {}: {}"
                .format(self.client_address, self.request_port, err))
            self.logger.error(traceback.format_exc())

    def addHandler(self, clientId, jobId, handler):
        self.active_handlers[clientId + "-" + jobId] = handler
        return handler

    def removeHandler(self, clientId, jobId):
        handlerId = clientId + "-" + jobId
        try:
            del self.active_handlers[handlerId]
        except KeyError:
            self.logger.error("Error removing handler: " + handlerId +
                              ", active handlers = " +
                              str(list(self.active_handlers.keys())))

    def setExeStatus(self, submissionId: str, status: Status):
        self.responder.setExeStatus(submissionId, status)

    def sendResponseMessage(self, msg: StratusResponse) -> str:
        request_args = [msg.id, msg.message]
        packaged_msg = "!".join(request_args)
        # Valid strftime directives; "MM/dd HH:mm:ss" would be printed literally.
        timeStamp = datetime.datetime.now().strftime("%m/%d %H:%M:%S")
        self.logger.info(
            "@@STRATUS-APP: Sending response {} on request_socket @({}): {}".
format(msg.id, timeStamp, str(msg))) self.request_socket.send_string(packaged_msg) return packaged_msg def initInteractions(self): try: self.zmqContext: zmq.Context = zmq.Context() self.auth = ThreadAuthenticator(self.zmqContext) self.auth.start() self.auth.allow("192.168.0.22") self.auth.allow(self.client_address) self.auth.configure_curve( domain='*', location=zmq.auth.CURVE_ALLOW_ANY ) # self.public_keys_dir ) # Use 'location=zmq.auth.CURVE_ALLOW_ANY' for stonehouse security self.request_socket: zmq.Socket = self.zmqContext.socket(zmq.REP) self.responder = StratusZMQResponder( self.zmqContext, self.response_port, client_address=self.client_address, certificate_path=self.cert_dir) self.initSocket() self.logger.info( "@@STRATUS-APP:Listening for requests on port: {}".format( self.request_port)) except Exception as err: self.logger.error( "@@STRATUS-APP: ------------------------------- StratusApp Init error: {} ------------------------------- " .format(err)) def processResults(self): completed_workflows = self.responder.processWorkflows( self.getWorkflows()) for rid in completed_workflows: self.clearWorkflow(rid) def processRequests(self): while self.request_socket.poll(0) != 0: request_header = self.request_socket.recv_string().strip().strip( "'") parts = request_header.split("!") submissionId = str(parts[0]) rType = str(parts[1]) request: Dict = json.loads(parts[2]) if len(parts) > 2 else "" try: self.logger.info( "@@STRATUS-APP: ### Processing {} request: {}".format( rType, request)) if rType == "capabilities": response = self.core.getCapabilities(request["type"]) self.sendResponseMessage( StratusResponse(submissionId, response)) elif rType == "exe": if len(parts) <= 2: raise Exception("Missing parameters to exe request") request["rid"] = submissionId self.logger.info( "Processing zmq Request: '{}' '{}' '{}'".format( submissionId, rType, str(request))) self.submitWorkflow( request) # TODO: Send results when tasks complete. response = {"status": "Executing"} self.sendResponseMessage( StratusResponse(submissionId, response)) elif rType == "quit" or rType == "shutdown": response = {"status": "Terminating"} self.sendResponseMessage( StratusResponse(submissionId, response)) self.logger.info( "@@STRATUS-APP: Received Shutdown Message") exit(0) else: msg = "@@STRATUS-APP: Unknown request type: " + rType self.logger.info(msg) response = {"status": "error", "error": msg} self.sendResponseMessage( StratusResponse(submissionId, response)) except Exception as ex: self.processError(submissionId, ex) def processError(self, rid: str, ex: Exception): tb = traceback.format_exc() self.logger.error("@@STRATUS-APP: Execution error: " + str(ex)) self.logger.error(tb) response = {"status": "error", "error": str(ex), "traceback": tb} self.sendResponseMessage(StratusResponse(rid, response)) def updateInteractions(self): self.processRequests() self.processResults() def term(self, msg): self.logger.info("@@STRATUS-APP: !!EDAS Shutdown: " + msg) self.active = False self.auth.stop() self.logger.info("@@STRATUS-APP: QUIT PythonWorkerPortal") try: self.request_socket.close() except Exception: pass self.logger.info("@@STRATUS-APP: CLOSE request_socket") self.responder.close_connection() self.logger.info("@@STRATUS-APP: TERM responder") self.shutdown() self.logger.info("@@STRATUS-APP: shutdown complete")
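initSocket() above leaves the CURVE socket options commented out "to avoid key checking". For reference, a self-contained sketch of the secured variant, using the same directory layout as getCertDirs() and the default request port:

import os

import zmq
import zmq.auth
from zmq.auth.thread import ThreadAuthenticator

cert_dir = os.path.expanduser("~/.stratus/zmq")
ctx = zmq.Context()

auth = ThreadAuthenticator(ctx)
auth.start()
# Swap CURVE_ALLOW_ANY for the public_keys directory to require known clients.
auth.configure_curve(domain='*', location=os.path.join(cert_dir, 'public_keys'))

rep = ctx.socket(zmq.REP)
public, secret = zmq.auth.load_certificate(
    os.path.join(cert_dir, 'private_keys', 'server.key_secret'))
rep.curve_secretkey = secret
rep.curve_publickey = public
rep.curve_server = True  # unlike the commented-out lines, set before bind()
rep.bind("tcp://*:4556")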
class RpcServer: """""" def __init__(self): """ Constructor """ # Save functions dict: key is fuction name, value is fuction object self.__functions: Dict[str, Any] = {} # Zmq port related self.__context: zmq.Context = zmq.Context() # Reply socket (Request–reply pattern) self.__socket_rep: zmq.Socket = self.__context.socket(zmq.REP) # Publish socket (Publish–subscribe pattern) self.__socket_pub: zmq.Socket = self.__context.socket(zmq.PUB) # Worker thread related self.__active: bool = False # RpcServer status self.__thread: threading.Thread = None # RpcServer thread self.__lock: threading.Lock = threading.Lock() # Authenticator used to ensure data security self.__authenticator: ThreadAuthenticator = None def is_active(self) -> bool: """""" return self.__active def start(self, rep_address: str, pub_address: str, server_secretkey_path: str = "", username: str = "", password: str = "") -> None: """ Start RpcServer """ if self.__active: return # Start authenticator if server_secretkey_path: self.__authenticator = ThreadAuthenticator(self.__context) self.__authenticator.start() self.__authenticator.configure_curve( domain="*", location=zmq.auth.CURVE_ALLOW_ANY) publickey, secretkey = zmq.auth.load_certificate( server_secretkey_path) self.__socket_pub.curve_secretkey = secretkey self.__socket_pub.curve_publickey = publickey self.__socket_pub.curve_server = True self.__socket_rep.curve_secretkey = secretkey self.__socket_rep.curve_publickey = publickey self.__socket_rep.curve_server = True elif username and password: self.__authenticator = ThreadAuthenticator(self.__context) self.__authenticator.start() self.__authenticator.configure_plain( domain="*", passwords={username: password}) self.__socket_pub.plain_server = True self.__socket_rep.plain_server = True # Bind socket address self.__socket_rep.bind(rep_address) self.__socket_pub.bind(pub_address) # Start RpcServer status self.__active = True # Start RpcServer thread self.__thread = threading.Thread(target=self.run) self.__thread.start() def stop(self) -> None: """ Stop RpcServer """ if not self.__active: return # Stop RpcServer status self.__active = False def join(self) -> None: # Wait for RpcServer thread to exit if self.__thread and self.__thread.is_alive(): self.__thread.join() self.__thread = None def run(self) -> None: """ Run RpcServer functions """ start = datetime.utcnow() while self.__active: # Use poll to wait event arrival, waiting time is 1 second (1000 milliseconds) cur = datetime.utcnow() delta = cur - start if delta >= KEEP_ALIVE_INTERVAL: self.publish(KEEP_ALIVE_TOPIC, cur) if not self.__socket_rep.poll(1000): continue # Receive request data from Reply socket req = self.__socket_rep.recv_pyobj() # Get function name and parameters name, args, kwargs = req # Try to get and execute callable function object; capture exception information if it fails try: func = self.__functions[name] r = func(*args, **kwargs) rep = [True, r] except Exception as e: # noqa rep = [False, traceback.format_exc()] # send callable response by Reply socket self.__socket_rep.send_pyobj(rep) # Unbind socket address self.__socket_pub.unbind(self.__socket_pub.LAST_ENDPOINT) self.__socket_rep.unbind(self.__socket_rep.LAST_ENDPOINT) if self.__authenticator: self.__authenticator.stop() def publish(self, topic: str, data: Any) -> None: """ Publish data """ with self.__lock: self.__socket_pub.send_pyobj([topic, data]) def register(self, func: Callable) -> None: """ Register function """ self.__functions[func.__name__] = func
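A minimal end-to-end sketch of RpcServer with PLAIN authentication; credentials and addresses are examples, and the setup pairs with the username/password path of RpcClient above:

server = RpcServer()

def add(a, b):
    return a + b

server.register(add)
server.start(
    rep_address="tcp://*:2014",
    pub_address="tcp://*:4102",
    username="rpc",
    password="secret",
)

# Clients connect with matching plain_username/plain_password; they can now
# call add(...) remotely or subscribe to publications:
server.publish("status", "ready")

server.stop()
server.join()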
class ZmqReceiver(object):

    def __init__(self, zmq_rep_bind_address=None, zmq_sub_connect_addresses=None,
                 recreate_sockets_on_timeout_of_sec=600, username=None,
                 password=None):
        self.context = zmq.Context()
        self.auth = None
        self.last_received_message = None
        self.is_running = False
        self.thread = None
        self.zmq_rep_bind_address = zmq_rep_bind_address
        self.zmq_sub_connect_addresses = zmq_sub_connect_addresses
        self.poller = zmq.Poller()
        self.sub_sockets = []
        self.rep_socket = None
        if username is not None and password is not None:
            # Start an authenticator for this context.
            # Does not work on PUB/SUB as far as I know (probably because the
            # more secure solutions require two-way communication as well).
            self.auth = ThreadAuthenticator(self.context)
            self.auth.start()
            # Instruct authenticator to handle PLAIN requests
            self.auth.configure_plain(domain='*',
                                      passwords={username: password})
        if self.zmq_sub_connect_addresses:
            for address in self.zmq_sub_connect_addresses:
                self.sub_sockets.append(
                    SubSocket(self.context, self.poller, address,
                              recreate_sockets_on_timeout_of_sec))
        if zmq_rep_bind_address:
            self.rep_socket = RepSocket(self.context, self.poller,
                                        zmq_rep_bind_address, self.auth)

    # May take up to 60 seconds to actually stop since poller has timeout of 60 seconds
    def stop(self):
        self.is_running = False
        logger.info("Closing pub and sub sockets...")
        if self.auth is not None:
            self.auth.stop()

    def run(self):
        self.is_running = True
        while self.is_running:
            socks = dict(self.poller.poll(1000))
            logger.debug("Poll cycle over. checking sockets")
            if self.rep_socket:
                incoming_message = self.rep_socket.recv_string(socks)
                if incoming_message is not None:
                    self.last_received_message = incoming_message
                    try:
                        logger.debug("Got info from REP socket")
                        response_message = self.handle_incoming_message(
                            incoming_message)
                        self.rep_socket.send(response_message)
                    except Exception as e:
                        logger.error(e)
            for sub_socket in self.sub_sockets:
                incoming_message = sub_socket.recv_string(socks)
                if incoming_message is not None:
                    if incoming_message != "zmq_sub_heartbeat":
                        self.last_received_message = incoming_message
                    logger.debug("Got info from SUB socket")
                    try:
                        self.handle_incoming_message(incoming_message)
                    except Exception as e:
                        logger.error(e)
        if self.rep_socket:
            self.rep_socket.destroy()
        for sub_socket in self.sub_sockets:
            sub_socket.destroy()

    def create_response_message(self, status_code, status_message,
                                response_message):
        if response_message is not None:
            return json.dumps({
                "status_code": status_code,
                "status_message": status_message,
                "response_message": response_message
            })
        else:
            return json.dumps({
                "status_code": status_code,
                "status_message": status_message
            })

    def handle_incoming_message(self, message):
        if message != "zmq_sub_heartbeat":
            return self.create_response_message(200, "OK", None)
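For the PLAIN-authenticated REP path, the matching requester looks roughly like this; the address and credentials are placeholders, and the receiver replies with a JSON status to any message other than its "zmq_sub_heartbeat" keep-alive:

import zmq

ctx = zmq.Context()
req = ctx.socket(zmq.REQ)
req.plain_username = b"user"      # must match the receiver's configure_plain()
req.plain_password = b"password"
req.connect("tcp://localhost:7777")  # must match zmq_rep_bind_address

req.send_string("ping")
print(req.recv_string())          # -> {"status_code": 200, "status_message": "OK"}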
class Command(LAVADaemonCommand):
    help = "LAVA log recorder"
    logger = None
    default_logfile = "/var/log/lava-server/lava-logs.log"

    def __init__(self, *args, **options):
        super().__init__(*args, **options)
        self.logger = logging.getLogger("lava-logs")
        self.log_socket = None
        self.auth = None
        self.controler = None
        self.inotify_fd = None
        self.pipe_r = None
        self.poller = None
        self.cert_dir_path = None
        # List of logs
        self.jobs = {}
        # Keep test cases in memory
        self.test_cases = []
        # Master status
        self.last_ping = 0
        self.ping_interval = TIMEOUT

    def add_arguments(self, parser):
        super().add_arguments(parser)
        net = parser.add_argument_group("network")
        net.add_argument('--socket', default='tcp://*:5555',
                         help="Socket waiting for logs. Default: tcp://*:5555")
        net.add_argument('--master-socket', default='tcp://localhost:5556',
                         help="Socket for master-slave communication. Default: tcp://localhost:5556")
        net.add_argument('--ipv6', default=False, action='store_true',
                         help="Enable IPv6 on the listening sockets")
        net.add_argument('--encrypt', default=False, action='store_true',
                         help="Encrypt messages")
        net.add_argument('--master-cert',
                         default='/etc/lava-dispatcher/certificates.d/master.key_secret',
                         help="Certificate for the master socket")
        net.add_argument('--slaves-certs',
                         default='/etc/lava-dispatcher/certificates.d',
                         help="Directory for slaves certificates")

    def handle(self, *args, **options):
        # Initialize logging.
        self.setup_logging("lava-logs", options["level"], options["log_file"], FORMAT)

        self.logger.info("[INIT] Dropping privileges")
        if not self.drop_privileges(options['user'], options['group']):
            self.logger.error("[INIT] Unable to drop privileges")
            return

        filename = os.path.join(settings.MEDIA_ROOT, 'lava-logs-config.yaml')
        self.logger.debug("[INIT] Dumping config to %s", filename)
        with open(filename, 'w') as output:
            yaml.dump(options, output)

        # Create the sockets
        context = zmq.Context()
        self.log_socket = context.socket(zmq.PULL)
        self.controler = context.socket(zmq.ROUTER)
        self.controler.setsockopt(zmq.IDENTITY, b"lava-logs")
        # Limit the number of messages in the queue
        self.controler.setsockopt(zmq.SNDHWM, 2)
        # From http://api.zeromq.org/4-2:zmq-setsockopt#toc5
        # "Immediately readies that connection for data transfer with the master"
        self.controler.setsockopt(zmq.CONNECT_RID, b"master")

        if options['ipv6']:
            self.logger.info("[INIT] Enabling IPv6")
            self.log_socket.setsockopt(zmq.IPV6, 1)
            self.controler.setsockopt(zmq.IPV6, 1)

        if options['encrypt']:
            self.logger.info("[INIT] Starting encryption")
            try:
                self.auth = ThreadAuthenticator(context)
                self.auth.start()
                self.logger.debug("[INIT] Opening master certificate: %s",
                                  options['master_cert'])
                master_public, master_secret = zmq.auth.load_certificate(options['master_cert'])
                self.logger.debug("[INIT] Using slaves certificates from: %s",
                                  options['slaves_certs'])
                self.auth.configure_curve(domain='*', location=options['slaves_certs'])
            except OSError as err:
                self.logger.error("[INIT] %s", err)
                self.auth.stop()
                return
            self.log_socket.curve_publickey = master_public
            self.log_socket.curve_secretkey = master_secret
            self.log_socket.curve_server = True
            self.controler.curve_publickey = master_public
            self.controler.curve_secretkey = master_secret
            self.controler.curve_serverkey = master_public

            self.logger.debug("[INIT] Watching %s", options["slaves_certs"])
            self.cert_dir_path = options["slaves_certs"]
            self.inotify_fd = watch_directory(options["slaves_certs"])
            if self.inotify_fd is None:
                self.logger.error("[INIT] Unable to start inotify")

        self.log_socket.bind(options['socket'])
        self.controler.connect(options['master_socket'])

        # Poll on the sockets. This allows us to have a nice timeout along
        # with polling.
        self.poller = zmq.Poller()
        self.poller.register(self.log_socket, zmq.POLLIN)
        self.poller.register(self.controler, zmq.POLLIN)
        if self.inotify_fd is not None:
            self.poller.register(os.fdopen(self.inotify_fd), zmq.POLLIN)

        # Translate signals into zmq messages
        (self.pipe_r, _) = self.setup_zmq_signal_handler()
        self.poller.register(self.pipe_r, zmq.POLLIN)

        self.logger.info("[INIT] listening for logs")
        # PING right now: the master is waiting for this message to start
        # scheduling.
        self.controler.send_multipart([b"master", b"PING"])

        try:
            self.main_loop()
        except BaseException as exc:
            self.logger.error("[EXIT] Unknown exception raised, leaving!")
            self.logger.exception(exc)

        # Close the controler socket
        self.controler.close(linger=0)
        self.poller.unregister(self.controler)

        # Carefully close the logging socket as we don't want to lose messages
        self.logger.info("[EXIT] Disconnect logging socket and process messages")
        endpoint = u(self.log_socket.getsockopt(zmq.LAST_ENDPOINT))
        self.logger.debug("[EXIT] unbinding from '%s'", endpoint)
        self.log_socket.unbind(endpoint)

        # Empty the queue
        try:
            while self.wait_for_messages(True):
                # Flush the test cases cache for every iteration because we
                # might get killed soon.
                self.flush_test_cases()
        except BaseException as exc:
            self.logger.error("[EXIT] Unknown exception raised, leaving!")
            self.logger.exception(exc)
        finally:
            # Last flush
            self.flush_test_cases()
            self.logger.info("[EXIT] Closing the logging socket: the queue is empty")
            self.log_socket.close()
            if options['encrypt']:
                self.auth.stop()
            context.term()

    def flush_test_cases(self):
        if not self.test_cases:
            return
        # Try to save into the database
        try:
            TestCase.objects.bulk_create(self.test_cases)
            self.logger.info("Saving %d test cases", len(self.test_cases))
            self.test_cases = []
        except DatabaseError as exc:
            self.logger.error("Unable to flush the test cases")
            self.logger.exception(exc)
            self.logger.warning("Saving test cases one by one and dropping the faulty ones")
            saved = 0
            for tc in self.test_cases:
                with contextlib.suppress(DatabaseError):
                    tc.save()
                    saved += 1
            self.logger.info("%d test cases saved, %d dropped",
                             saved, len(self.test_cases) - saved)
            self.test_cases = []

    def main_loop(self):
        last_gc = time.time()
        last_bulk_create = time.time()

        # Wait for messages
        # TODO: fix timeout computation
        while self.wait_for_messages(False):
            now = time.time()

            # Dump TestCase into the database
            if now - last_bulk_create > BULK_CREATE_TIMEOUT:
                last_bulk_create = now
                self.flush_test_cases()

            # Close old file handlers
            if now - last_gc > FD_TIMEOUT:
                last_gc = now
                # Iterating while removing keys is not compatible with an iterator
                for job_id in list(self.jobs.keys()):  # pylint: disable=consider-iterating-dictionary
                    if now - self.jobs[job_id].last_usage > FD_TIMEOUT:
                        self.logger.info("[%s] closing log file", job_id)
                        self.jobs[job_id].close()
                        del self.jobs[job_id]

            # Ping the master
            if now - self.last_ping > self.ping_interval:
                self.logger.debug("PING => master")
                self.last_ping = now
                self.controler.send_multipart([b"master", b"PING"])

    def wait_for_messages(self, leaving):
        try:
            try:
                sockets = dict(self.poller.poll(TIMEOUT * 1000))
            except zmq.error.ZMQError as exc:
                self.logger.error("[POLL] zmq error: %s", str(exc))
                return True

            # Messages
            if sockets.get(self.log_socket) == zmq.POLLIN:
                self.logging_socket()
                return True
            # Signals
            elif sockets.get(self.pipe_r) == zmq.POLLIN:
                # remove the message from the queue
                os.read(self.pipe_r, 1)
                if not leaving:
                    self.logger.info("[POLL] received a signal, leaving")
                    return False
                else:
                    self.logger.warning("[POLL] signal already handled, please wait for the process to exit")
                    return True
            # Pong received
            elif sockets.get(self.controler) == zmq.POLLIN:
                self.controler_socket()
                return True
            # Inotify socket
            if sockets.get(self.inotify_fd) == zmq.POLLIN:
                os.read(self.inotify_fd, 4096)
                self.logger.debug("[AUTH] Reloading certificates from %s",
                                  self.cert_dir_path)
                self.auth.configure_curve(domain='*', location=self.cert_dir_path)
            # Nothing received
            else:
                return not leaving
        except (OperationalError, InterfaceError):
            self.logger.info("[RESET] database connection reset")
            connection.close()
        return True

    def logging_socket(self):
        msg = self.log_socket.recv_multipart()
        try:
            (job_id, message) = (u(m) for m in msg)  # pylint: disable=unbalanced-tuple-unpacking
        except ValueError:
            # do not let a bad message stop the master.
            self.logger.error("[POLL] failed to parse log message, skipping: %s", msg)
            return

        try:
            scanned = yaml.load(message, Loader=yaml.CLoader)
        except yaml.YAMLError:
            self.logger.error("[%s] data are not valid YAML, dropping", job_id)
            return

        # Look for "results" level
        try:
            message_lvl = scanned["lvl"]
            message_msg = scanned["msg"]
        except TypeError:
            self.logger.error("[%s] not a dictionary, dropping", job_id)
            return
        except KeyError:
            self.logger.error(
                "[%s] invalid log line, missing \"lvl\" or \"msg\" keys: %s",
                job_id, message)
            return

        # Find the handler (if available)
        if job_id not in self.jobs:
            # Query the database for the job
            try:
                job = TestJob.objects.get(id=job_id)
            except TestJob.DoesNotExist:
                self.logger.error("[%s] unknown job id", job_id)
                return
            self.logger.info("[%s] receiving logs from a new job", job_id)
            # Create the sub directories (if needed)
            mkdir(job.output_dir)
            self.jobs[job_id] = JobHandler(job)

        # For 'event', send an event and log as 'debug'
        if message_lvl == 'event':
            self.logger.debug("[%s] event: %s", job_id, message_msg)
            send_event(".event", "lavaserver", {"message": message_msg, "job": job_id})
            message_lvl = "debug"
        # For 'marker', save in the database and log as 'debug'
        elif message_lvl == 'marker':
            # TODO: save on the file system in case of lava-logs restart
            m_type = message_msg.get("type")
            case = message_msg.get("case")
            if m_type is None or case is None:
                self.logger.error("[%s] invalid marker: %s", job_id, message_msg)
                return
            self.jobs[job_id].markers.setdefault(case, {})[m_type] = self.jobs[job_id].line_count()
            # This is in fact the previous line
            self.jobs[job_id].markers[case][m_type] -= 1
            self.logger.debug("[%s] marker: %s line: %s", job_id, message_msg,
                              self.jobs[job_id].markers[case][m_type])
            return

        # Mark the file handler as used
        self.jobs[job_id].last_usage = time.time()

        # The format is a list of dictionaries
        self.jobs[job_id].write("- %s" % message)

        if message_lvl == "results":
            try:
                job = TestJob.objects.get(pk=job_id)
            except TestJob.DoesNotExist:
                self.logger.error("[%s] unknown job id", job_id)
                return
            meta_filename = create_metadata_store(message_msg, job)
            new_test_case = map_scanned_results(results=message_msg, job=job,
                                                markers=self.jobs[job_id].markers,
                                                meta_filename=meta_filename)
            if new_test_case is None:
                self.logger.warning("[%s] unable to map scanned results: %s",
                                    job_id, message)
            else:
                self.test_cases.append(new_test_case)

            # Look for lava.job result
            if message_msg.get("definition") == "lava" and message_msg.get("case") == "job":
                # Flush cached test cases
                self.flush_test_cases()
                if message_msg.get("result") == "pass":
                    health = TestJob.HEALTH_COMPLETE
                    health_msg = "Complete"
                else:
                    health = TestJob.HEALTH_INCOMPLETE
                    health_msg = "Incomplete"
                self.logger.info("[%s] job status: %s", job_id, health_msg)

                infrastructure_error = (message_msg.get("error_type") in
                                        ["Bug", "Configuration", "Infrastructure"])
                if infrastructure_error:
                    self.logger.info("[%s] Infrastructure error", job_id)

                # Update status.
                with transaction.atomic():
                    # TODO: find a way to lock actual_device
                    job = TestJob.objects.select_for_update().get(id=job_id)
                    job.go_state_finished(health, infrastructure_error)
                    job.save()
        # n.b. logging here would produce a log entry for every message in every job.

    def controler_socket(self):
        msg = self.controler.recv_multipart()
        try:
            master_id = u(msg[0])
            action = u(msg[1])
            ping_interval = int(msg[2])
            if master_id != "master":
                self.logger.error("Invalid master id '%s'. Should be 'master'", master_id)
                return
            if action != "PONG":
                self.logger.error("Invalid answer '%s'. Should be 'PONG'", action)
                return
        except (IndexError, ValueError):
            self.logger.error("Invalid message '%s'", msg)
            return

        if ping_interval < TIMEOUT:
            self.logger.error("invalid ping interval (%d) too small", ping_interval)
            return

        self.logger.debug("master => PONG(%d)", ping_interval)
        self.ping_interval = ping_interval
class Command(LAVADaemonCommand):
    """
    worker_host is the hostname of the worker; this field is set by the admin
    and could therefore be empty in a misconfigured instance.
    """
    logger = None
    help = "LAVA dispatcher master"
    default_logfile = "/var/log/lava-server/lava-master.log"

    def __init__(self, *args, **options):
        super(Command, self).__init__(*args, **options)
        self.auth = None
        self.controler = None
        self.event_socket = None
        self.poller = None
        self.pipe_r = None
        self.inotify_fd = None
        # List of logs
        # List of known dispatchers. At startup do not load this from the
        # database. This will help to know if the slave has restarted or not.
        self.dispatchers = {"lava-logs": SlaveDispatcher("lava-logs", online=False)}
        self.events = {"canceling": set()}

    def add_arguments(self, parser):
        super(Command, self).add_arguments(parser)
        # Important: ensure share/env.yaml is put into /etc/ by setup.py in packaging.
        config = parser.add_argument_group("dispatcher config")
        config.add_argument('--env', default="/etc/lava-server/env.yaml",
                            help="Environment variables for the dispatcher processes. "
                                 "Default: /etc/lava-server/env.yaml")
        config.add_argument('--env-dut', default="/etc/lava-server/env.dut.yaml",
                            help="Environment variables for device under test. "
                                 "Default: /etc/lava-server/env.dut.yaml")
        config.add_argument('--dispatchers-config',
                            default="/etc/lava-server/dispatcher.d",
                            help="Directory that might contain dispatcher specific configuration")

        net = parser.add_argument_group("network")
        net.add_argument('--master-socket', default='tcp://*:5556',
                         help="Socket for master-slave communication. Default: tcp://*:5556")
        net.add_argument('--event-url', default="tcp://localhost:5500",
                         help="URL of the publisher")
        net.add_argument('--ipv6', default=False, action='store_true',
                         help="Enable IPv6 on the listening sockets")
        net.add_argument('--encrypt', default=False, action='store_true',
                         help="Encrypt messages")
        net.add_argument('--master-cert',
                         default='/etc/lava-dispatcher/certificates.d/master.key_secret',
                         help="Certificate for the master socket")
        net.add_argument('--slaves-certs',
                         default='/etc/lava-dispatcher/certificates.d',
                         help="Directory for slaves certificates")

    def send_status(self, hostname):
        """
        The master crashed; send a STATUS message to get the current state of jobs.
        """
        jobs = TestJob.objects.filter(actual_device__worker_host__hostname=hostname,
                                      state=TestJob.STATE_RUNNING)
        for job in jobs:
            self.logger.info("[%d] STATUS => %s (%s)", job.id, hostname,
                             job.actual_device.hostname)
            send_multipart_u(self.controler, [hostname, 'STATUS', str(job.id)])

    def dispatcher_alive(self, hostname):
        if hostname not in self.dispatchers:
            # The server crashed: send a STATUS message
            self.logger.warning("Unknown dispatcher <%s> (server crashed)", hostname)
            self.dispatchers[hostname] = SlaveDispatcher(hostname)
            self.send_status(hostname)
        # Mark the dispatcher as alive
        self.dispatchers[hostname].alive()

    def controler_socket(self):
        try:
            # We need to use the zmq.NOBLOCK flag here, otherwise we could
            # block the whole main loop where this function is called.
            msg = self.controler.recv_multipart(zmq.NOBLOCK)
        except zmq.error.Again:
            return False
        # This is way too verbose for production and should only be activated
        # by (and for) developers
        # self.logger.debug("[CC] Receiving: %s", msg)

        # 1: the hostname (see ZMQ documentation)
        hostname = u(msg[0])
        # 2: the action
        action = u(msg[1])

        # Check that lava-logs only sends PINGs
        if hostname == "lava-logs" and action != "PING":
            self.logger.error("%s => %s Invalid action from log daemon",
                              hostname, action)
            return True

        # Handle the actions
        if action == 'HELLO' or action == 'HELLO_RETRY':
            self._handle_hello(hostname, action, msg)
        elif action == 'PING':
            self._handle_ping(hostname, action, msg)
        elif action == 'END':
            self._handle_end(hostname, action, msg)
        elif action == 'START_OK':
            self._handle_start_ok(hostname, action, msg)
        else:
            self.logger.error("<%s> sent unknown action=%s, args=(%s)",
                              hostname, action, msg[1:])
        return True

    def read_event_socket(self):
        try:
            msg = self.event_socket.recv_multipart(zmq.NOBLOCK)
        except zmq.error.Again:
            return False

        try:
            (topic, _, dt, username, data) = (u(m) for m in msg)
        except ValueError:
            self.logger.error("Invalid event: %s", msg)
            return True

        if topic.endswith(".testjob"):
            try:
                data = simplejson.loads(data)
                if data["state"] == "Canceling":
                    self.events["canceling"].add(int(data["job"]))
            except ValueError:
                self.logger.error("Invalid event data: %s", msg)
        return True

    def _handle_end(self, hostname, action, msg):  # pylint: disable=unused-argument
        try:
            job_id = int(msg[2])
            error_msg = msg[3]
            compressed_description = msg[4]
        except (IndexError, ValueError):
            self.logger.error("Invalid message from <%s> '%s'", hostname, msg)
            return

        try:
            job = TestJob.objects.get(id=job_id)
        except TestJob.DoesNotExist:
            self.logger.error("[%d] Unknown job", job_id)
            # ACK even if the job is unknown to let the dispatcher
            # forget about it
            send_multipart_u(self.controler, [hostname, 'END_OK', str(job_id)])
            return

        filename = os.path.join(job.output_dir, 'description.yaml')
        # If description.yaml already exists: an END was already received
        if os.path.exists(filename):
            self.logger.info("[%d] %s => END (duplicated), skipping", job_id, hostname)
        else:
            if compressed_description:
                self.logger.info("[%d] %s => END", job_id, hostname)
            else:
                self.logger.info("[%d] %s => END (lava-run crashed, mark job as INCOMPLETE)",
                                 job_id, hostname)
                with transaction.atomic():
                    # TODO: find a way to lock actual_device
                    job = TestJob.objects.select_for_update().get(id=job_id)
                    job.go_state_finished(TestJob.HEALTH_INCOMPLETE)
                    if error_msg:
                        self.logger.error("[%d] Error: %s", job_id, error_msg)
                        job.failure_comment = error_msg
                    job.save()

            # Create description.yaml even if it's empty.
            # Allows to know when END messages are duplicated.
            try:
                # Create the directory if it was not already created
                mkdir(os.path.dirname(filename))
                # TODO: check that compressed_description is not ""
                description = lzma.decompress(compressed_description)
                with open(filename, 'w') as f_description:
                    f_description.write(description.decode("utf-8"))
                if description:
                    parse_job_description(job)
            except (IOError, lzma.LZMAError) as exc:
                self.logger.error("[%d] Unable to dump 'description.yaml'", job_id)
                self.logger.exception("[%d] %s", job_id, exc)

        # ACK the job and mark the dispatcher as alive
        send_multipart_u(self.controler, [hostname, 'END_OK', str(job_id)])
        self.dispatcher_alive(hostname)

    def _handle_hello(self, hostname, action, msg):
        # Check the protocol version
        try:
            slave_version = int(msg[2])
        except (IndexError, ValueError):
            self.logger.error("Invalid message from <%s> '%s'", hostname, msg)
            return

        self.logger.info("%s => %s", hostname, action)
        if slave_version != PROTOCOL_VERSION:
            self.logger.error("<%s> using protocol v%d while master is using v%d",
                              hostname, slave_version, PROTOCOL_VERSION)
            return
        send_multipart_u(self.controler, [hostname, 'HELLO_OK'])

        # If the dispatcher is known and sent an HELLO, that means the slave
        # has restarted
        if hostname in self.dispatchers:
            if action == 'HELLO':
                self.logger.warning("Dispatcher <%s> has RESTARTED", hostname)
            else:
                # Assume the HELLO command was received, and the
                # action succeeded.
                self.logger.warning("Dispatcher <%s> was not confirmed", hostname)
        else:
            # No dispatcher: treat HELLO and HELLO_RETRY as a normal HELLO
            # message.
            self.logger.warning("New dispatcher <%s>", hostname)
            self.dispatchers[hostname] = SlaveDispatcher(hostname)

        # Mark the dispatcher as alive
        self.dispatcher_alive(hostname)

    def _handle_ping(self, hostname, action, msg):  # pylint: disable=unused-argument
        self.logger.debug("%s => PING(%d)", hostname, PING_INTERVAL)
        # Send back a signal
        send_multipart_u(self.controler, [hostname, 'PONG', str(PING_INTERVAL)])
        self.dispatcher_alive(hostname)

    def _handle_start_ok(self, hostname, action, msg):  # pylint: disable=unused-argument
        try:
            job_id = int(msg[2])
        except (IndexError, ValueError):
            self.logger.error("Invalid message from <%s> '%s'", hostname, msg)
            return
        self.logger.info("[%d] %s => START_OK", job_id, hostname)
        try:
            with transaction.atomic():
                # TODO: find a way to lock actual_device
                job = TestJob.objects.select_for_update().get(id=job_id)
                job.go_state_running()
                job.save()
        except TestJob.DoesNotExist:
            self.logger.error("[%d] Unknown job", job_id)
        else:
            self.dispatcher_alive(hostname)

    def export_definition(self, job):  # pylint: disable=no-self-use
        job_def = yaml.load(job.definition)
        job_def['compatibility'] = job.pipeline_compatibility
        # no need for the dispatcher to retain comments
        return yaml.dump(job_def)

    def save_job_config(self, job, worker, device_cfg, options):
        output_dir = job.output_dir
        mkdir(output_dir)
        with open(os.path.join(output_dir, "job.yaml"), "w") as f_out:
            f_out.write(self.export_definition(job))
        with contextlib.suppress(IOError):
            shutil.copy(options["env"], os.path.join(output_dir, "env.yaml"))
        with contextlib.suppress(IOError):
            shutil.copy(options["env_dut"], os.path.join(output_dir, "env.dut.yaml"))
        with contextlib.suppress(IOError):
            shutil.copy(os.path.join(options["dispatchers_config"],
                                     "%s.yaml" % worker.hostname),
                        os.path.join(output_dir, "dispatcher.yaml"))
        with open(os.path.join(output_dir, "device.yaml"), "w") as f_out:
            yaml.dump(device_cfg, f_out)

    def start_job(self, job, options):
        # Load the job definition to get the variables for template rendering
        job_def = yaml.load(job.definition)
        job_ctx = job_def.get('context', {})

        device = job.actual_device
        worker = device.worker_host

        # Load configurations
        env_str = load_optional_yaml_file(options['env'])
        env_dut_str = load_optional_yaml_file(options['env_dut'])
        device_cfg = device.load_configuration(job_ctx)
        dispatcher_cfg_file = os.path.join(options['dispatchers_config'],
                                           "%s.yaml" % worker.hostname)
        dispatcher_cfg = load_optional_yaml_file(dispatcher_cfg_file)

        self.save_job_config(job, worker, device_cfg, options)

        self.logger.info("[%d] START => %s (%s)", job.id,
                         worker.hostname, device.hostname)
        send_multipart_u(self.controler,
                         [worker.hostname, 'START', str(job.id),
                          self.export_definition(job),
                          yaml.dump(device_cfg), dispatcher_cfg,
                          env_str, env_dut_str])

        # For multinode jobs, start the dynamic connections
        parent = job
        for sub_job in job.sub_jobs_list:
            if sub_job == parent or not sub_job.dynamic_connection:
                continue
            # inherit only enough configuration for dynamic_connection operation
            self.logger.info("[%d] Trimming dynamic connection device configuration.",
                             sub_job.id)
            min_device_cfg = parent.actual_device.minimise_configuration(device_cfg)
            self.save_job_config(sub_job, worker, min_device_cfg, options)
            self.logger.info("[%d] START => %s (connection)",
                             sub_job.id, worker.hostname)
            send_multipart_u(self.controler,
                             [worker.hostname, 'START', str(sub_job.id),
                              self.export_definition(sub_job),
                              yaml.dump(min_device_cfg), dispatcher_cfg,
                              env_str, env_dut_str])

    def start_jobs(self, options):
        """
        Loop on all scheduled jobs and send the START message to the slave.
        """
        # make the request atomic
        query = TestJob.objects.select_for_update()
        # Only select test jobs that are ready
        query = query.filter(state=TestJob.STATE_SCHEDULED)
        # Only start jobs on online workers
        query = query.filter(actual_device__worker_host__state=Worker.STATE_ONLINE)
        # Exclude test jobs without a device: they are special test jobs like
        # dynamic connections.
        query = query.exclude(actual_device=None)
        # TODO: find a way to lock actual_device

        # Loop on all jobs
        for job in query:
            msg = None
            try:
                self.start_job(job, options)
            except jinja2.TemplateNotFound as exc:
                self.logger.error("[%d] Template not found: '%s'",
                                  job.id, exc.message)
                msg = "Template not found: '%s'" % exc.message
            except jinja2.TemplateSyntaxError as exc:
                self.logger.error("[%d] Template syntax error in '%s', line %d: %s",
                                  job.id, exc.name, exc.lineno, exc.message)
                msg = "Template syntax error in '%s', line %d: %s" % (
                    exc.name, exc.lineno, exc.message)
            except IOError as exc:
                self.logger.error("[%d] Unable to read '%s': %s",
                                  job.id, exc.filename, exc.strerror)
                msg = "Cannot open '%s': %s" % (exc.filename, exc.strerror)
            except yaml.YAMLError as exc:
                self.logger.error("[%d] Unable to parse job definition: %s",
                                  job.id, exc)
                msg = "Cannot parse job definition: %s" % exc

            if msg:
                # Add the error as a lava.job result
                metadata = {"case": "job",
                            "definition": "lava",
                            "error_type": "Infrastructure",
                            "error_msg": msg,
                            "result": "fail"}
                suite, _ = TestSuite.objects.get_or_create(name="lava", job=job)
                TestCase.objects.create(name="job", suite=suite,
                                        result=TestCase.RESULT_FAIL,
                                        metadata=yaml.dump(metadata))
                job.go_state_finished(TestJob.HEALTH_INCOMPLETE, True)
                job.save()

    def cancel_jobs(self, partial=False):
        query = TestJob.objects.filter(state=TestJob.STATE_CANCELING)
        if partial:
            query = query.filter(id__in=list(self.events["canceling"]))
        for job in query:
            worker = (job.lookup_worker if job.dynamic_connection
                      else job.actual_device.worker_host)
            self.logger.info("[%d] CANCEL => %s", job.id, worker.hostname)
            send_multipart_u(self.controler,
                             [worker.hostname, 'CANCEL', str(job.id)])

    def handle(self, *args, **options):
        # Initialize logging.
        self.setup_logging("lava-master", options["level"], options["log_file"], FORMAT)

        self.logger.info("[INIT] Dropping privileges")
        if not self.drop_privileges(options['user'], options['group']):
            self.logger.error("[INIT] Unable to drop privileges")
            return

        self.logger.info("[INIT] Marking all workers as offline")
        with transaction.atomic():
            for worker in Worker.objects.select_for_update().all():
                worker.go_state_offline()
                worker.save()

        # Create the sockets
        context = zmq.Context()
        self.controler = context.socket(zmq.ROUTER)
        self.event_socket = context.socket(zmq.SUB)

        if options['ipv6']:
            self.logger.info("[INIT] Enabling IPv6")
            self.controler.setsockopt(zmq.IPV6, 1)
            self.event_socket.setsockopt(zmq.IPV6, 1)

        if options['encrypt']:
            self.logger.info("[INIT] Starting encryption")
            try:
                self.auth = ThreadAuthenticator(context)
                self.auth.start()
                self.logger.debug("[INIT] Opening master certificate: %s",
                                  options['master_cert'])
                master_public, master_secret = zmq.auth.load_certificate(options['master_cert'])
                self.logger.debug("[INIT] Using slaves certificates from: %s",
                                  options['slaves_certs'])
                self.auth.configure_curve(domain='*', location=options['slaves_certs'])
            except IOError as err:
                self.logger.error(err)
                self.auth.stop()
                return
            self.controler.curve_publickey = master_public
            self.controler.curve_secretkey = master_secret
            self.controler.curve_server = True

            self.logger.debug("[INIT] Watching %s", options["slaves_certs"])
            self.inotify_fd = watch_directory(options["slaves_certs"])
            if self.inotify_fd is None:
                self.logger.error("[INIT] Unable to start inotify")

        self.controler.setsockopt(zmq.IDENTITY, b"master")
        # From http://api.zeromq.org/4-2:zmq-setsockopt#toc42
        # "If two clients use the same identity when connecting to a ROUTER
        # [...] the ROUTER socket shall hand-over the connection to the new
        # client and disconnect the existing one."
        self.controler.setsockopt(zmq.ROUTER_HANDOVER, 1)
        self.controler.bind(options['master_socket'])

        self.event_socket.setsockopt(zmq.SUBSCRIBE, b(settings.EVENT_TOPIC))
        self.event_socket.connect(options['event_url'])

        # Poll on the sockets. This allows us to have a nice timeout along
        # with polling.
        self.poller = zmq.Poller()
        self.poller.register(self.controler, zmq.POLLIN)
        self.poller.register(self.event_socket, zmq.POLLIN)
        if self.inotify_fd is not None:
            self.poller.register(os.fdopen(self.inotify_fd), zmq.POLLIN)

        # Translate signals into zmq messages
        (self.pipe_r, _) = self.setup_zmq_signal_handler()
        self.poller.register(self.pipe_r, zmq.POLLIN)

        self.logger.info("[INIT] LAVA master has started.")
        self.logger.info("[INIT] Using protocol version %d", PROTOCOL_VERSION)

        try:
            self.main_loop(options)
        except BaseException as exc:
            self.logger.error("[CLOSE] Unknown exception raised, leaving!")
            self.logger.exception(exc)
        finally:
            # Drop the controler socket: the protocol does handle lost messages
            self.logger.info("[CLOSE] Closing the controler socket and dropping messages")
            self.controler.close(linger=0)
            self.event_socket.close(linger=0)
            if options['encrypt']:
                self.auth.stop()
            context.term()

    def main_loop(self, options):
        last_schedule = last_dispatcher_check = time.time()

        while True:
            try:
                try:
                    # Compute the timeout
                    now = time.time()
                    timeout = min(SCHEDULE_INTERVAL - (now - last_schedule),
                                  PING_INTERVAL - (now - last_dispatcher_check))
                    # If some actions are remaining, decrease the timeout
                    if self.events["canceling"]:
                        timeout = min(timeout, 1)
                    # Wait at least 1ms
                    timeout = max(timeout * 1000, 1)

                    # Wait for data or a timeout
                    sockets = dict(self.poller.poll(timeout))
                except zmq.error.ZMQError:
                    continue

                if sockets.get(self.pipe_r) == zmq.POLLIN:
                    self.logger.info("[POLL] Received a signal, leaving")
                    break

                # Command socket
                if sockets.get(self.controler) == zmq.POLLIN:
                    while self.controler_socket():  # Unqueue all pending messages
                        pass

                # Events socket
                if sockets.get(self.event_socket) == zmq.POLLIN:
                    while self.read_event_socket():  # Unqueue all pending messages
                        pass
                    # Wait for the next iteration to handle the event.
                    # In fact, the code that generated the event (lava-logs or
                    # lava-server-gunicorn) needs some time to commit the
                    # database transaction.
                    # If we are too fast, the database object won't be
                    # available (or in the right state) yet.
                    continue

                # Inotify socket
                if sockets.get(self.inotify_fd) == zmq.POLLIN:
                    os.read(self.inotify_fd, 4096)
                    self.logger.debug("[AUTH] Reloading certificates from %s",
                                      options['slaves_certs'])
                    self.auth.configure_curve(domain='*',
                                              location=options['slaves_certs'])

                # Check dispatchers status
                now = time.time()
                if now - last_dispatcher_check > PING_INTERVAL:
                    for hostname, dispatcher in self.dispatchers.items():
                        if (dispatcher.online and
                                now - dispatcher.last_msg > DISPATCHER_TIMEOUT):
                            if hostname == "lava-logs":
                                self.logger.error("[STATE] lava-logs goes OFFLINE")
                            else:
                                self.logger.error("[STATE] Dispatcher <%s> goes OFFLINE",
                                                  hostname)
                            self.dispatchers[hostname].go_offline()
                    last_dispatcher_check = now

                # Limit accesses to the database. This will also limit the
                # rate of CANCEL and START messages.
                if time.time() - last_schedule > SCHEDULE_INTERVAL:
                    if self.dispatchers["lava-logs"].online:
                        schedule(self.logger)
                        # Dispatch scheduled jobs
                        with transaction.atomic():
                            self.start_jobs(options)
                    else:
                        self.logger.warning("lava-logs is offline: can't schedule jobs")

                    # Handle canceling jobs
                    self.cancel_jobs()

                    # Do not count the time taken to schedule jobs
                    last_schedule = time.time()
                else:
                    # Cancel the jobs and remove the jobs from the set
                    if self.events["canceling"]:
                        self.cancel_jobs(partial=True)
                        self.events["canceling"] = set()
            except (OperationalError, InterfaceError):
                self.logger.info("[RESET] database connection reset.")
                # Closing the database connection will force Django to reopen
                # the connection
                connection.close()
                time.sleep(2)
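# A minimal sketch of the slave side of the encrypted master socket above,
# assuming slave01.key_secret was created with zmq.auth.create_certificates
# and master.key (the public half) was copied over from the master. The
# paths, hostname and message framing are placeholders, not the real
# lava-slave protocol.
import zmq
import zmq.auth

context = zmq.Context()
slave = context.socket(zmq.DEALER)
slave.setsockopt(zmq.IDENTITY, b'slave01')

public, secret = zmq.auth.load_certificate('certs/slave01.key_secret')
slave.curve_publickey = public
slave.curve_secretkey = secret
# The client must know the server's public key to make a CURVE connection.
master_public, _ = zmq.auth.load_certificate('certs/master.key')
slave.curve_serverkey = master_public

slave.connect('tcp://localhost:5556')
slave.send_multipart([b'HELLO', b'1'])  # illustrative HELLO with a protocol version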
class ZmqReceiver(ZmqBase):
    '''
    A ZmqReceiver class will listen on a REP or SUB socket for messages
    and will call the 'handle_incoming_message()' method to process them.
    Subclasses should override that. A response must be implemented for
    REP sockets, but is useless for SUB sockets.
    '''

    def __init__(self, zmq_rep_bind_address: Optional[str] = None,
                 zmq_sub_connect_addresses: Optional[Tuple[SubSocketAddress, ...]] = None,
                 recreate_timeout: Optional[int] = 600,
                 username: Optional[str] = None,
                 password: Optional[str] = None):
        super().__init__()
        self.__context = zmq.Context()
        self.__poller = zmq.Poller()
        self.__sub_sockets = tuple(
            SubSocket(
                ctx=self.__context,
                poller=self.__poller,
                address=address,
                timeout_in_sec=recreate_timeout,
            ) for address in (zmq_sub_connect_addresses or tuple()))
        self.__auth: Optional[ThreadAuthenticator] = None
        if username is not None and password is not None:
            # Start an authenticator for this context.
            # Does not work on PUB/SUB as far as I know (probably because the
            # more secure solutions require two way communication as well)
            self.__auth = ThreadAuthenticator(self.__context)
            self.__auth.start()
            # Instruct the authenticator to handle PLAIN requests
            self.__auth.configure_plain(domain='*', passwords={
                username: password,
            })
        self.__rep_socket = RepSocket(
            ctx=self.__context,
            poller=self.__poller,
            address=zmq_rep_bind_address,
            auth=self.__auth,
        ) if zmq_rep_bind_address else None
        self.__last_received_message = None
        self.__is_running = False

    @property
    def is_running(self) -> bool:
        return self.__is_running

    def stop(self) -> None:
        '''
        May take up to 60 seconds to actually stop since the poller has a
        timeout of 60 seconds.
        '''
        self._info('Closing pub and sub sockets...')
        self.__is_running = False
        if self.__auth is not None:
            self.__auth.stop()

    def _run_rep_socket(self, socks) -> None:
        if self.__rep_socket is None:
            return
        incoming_message = self.__rep_socket.recv_string(socks)
        if incoming_message is None:
            return
        if incoming_message != self.HEARTBEAT_MSG:
            self.__last_received_message = incoming_message
        self._debug('Got info from REP socket')
        try:
            response_message = self.handle_incoming_message(incoming_message)
            self.__rep_socket.send(response_message)
        except Exception as e:
            self._error(e)

    def _run_sub_sockets(self, socks) -> None:
        for sub_socket in self.__sub_sockets:
            incoming_message = sub_socket.recv_string(socks)
            if incoming_message is None:
                continue
            if incoming_message != self.HEARTBEAT_MSG:
                self.__last_received_message = incoming_message
            self._debug('Got info from SUB socket')
            try:
                self.handle_incoming_message(incoming_message)
            except Exception as e:
                self._error(e)

    def run(self) -> None:
        self.__is_running = True
        while self.__is_running:
            try:
                socks = dict(self.__poller.poll(1000))
            except BaseException as e:
                self._error(e)
                continue
            self._debug('Poll cycle over. Checking sockets')
            self._run_rep_socket(socks)
            self._run_sub_sockets(socks)
        if self.__rep_socket:
            self.__rep_socket.destroy()
        for sub_socket in self.__sub_sockets:
            sub_socket.destroy()

    def create_response_message(self, status_code: int, status_message: str,
                                response_message: Optional[str] = None) -> str:
        payload = {
            self.STATUS_CODE: status_code,
            self.STATUS_MSG: status_message,
        }
        if response_message is not None:
            payload[self.RESPONSE_MSG] = response_message
        return json.dumps(payload)

    def handle_incoming_message(self, message: str) -> Optional[str]:
        if message == self.HEARTBEAT_MSG:
            return None
        return self.create_response_message(
            status_code=self.STATUS_CODE_OK,
            status_message=self.STATUS_MSG_OK,
        )

    def get_last_received_message(self) -> Optional[str]:
        return self.__last_received_message

    def get_sub_socket(self, idx: int) -> SubSocket:
        return self.__sub_sockets[idx]
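# A minimal usage sketch for the typed receiver above: subclass it, override
# handle_incoming_message(), and run it. The endpoint and credentials are
# placeholders; passing username/password enables the PLAIN authenticator.
class PrintingReceiver(ZmqReceiver):
    def handle_incoming_message(self, message):
        print('got:', message)
        return self.create_response_message(
            status_code=self.STATUS_CODE_OK,
            status_message=self.STATUS_MSG_OK,
        )

receiver = PrintingReceiver(
    zmq_rep_bind_address='tcp://*:47001',
    username='admin',
    password='secret',
)
receiver.run()  # blocks; call receiver.stop() from another thread to exit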
class MultiNodeAgent(BEMOSSAgent):
    '''Listens to everything and publishes a heartbeat according to the
    heartbeat period specified in the settings module.
    '''

    def __init__(self, config_path, **kwargs):
        super(MultiNodeAgent, self).__init__(**kwargs)
        self.agent_id = 'multinodeagent'
        self.identity = self.agent_id
        self.multinode_status = dict()
        self.is_parent = False
        self.last_sync_with_parent = datetime(1991, 1, 1)  # equivalent to -ve infinity
        self.parent_node = None
        self.curcon = None  # initialize database connection
        self.recently_online_node_list = []   # initialize the lists to empty;
        self.recently_offline_node_list = []  # they will be filled as nodes are discovered to be online/offline
        self.offline_variables = offline_variables
        self.offline_variables['logged_by'] = self.agent_id
        self.offline_table = offline_table
        self.offline_log_variables = offline_log_variables

    def getMultinodeData(self):
        self.multinode_data = db_helper.get_multinode_data()
        self.nodelist_dict = {node['name']: node for node in self.multinode_data['known_nodes']}
        self.node_name_list = [node['name'] for node in self.multinode_data['known_nodes']]
        self.address_list = [node['address'] for node in self.multinode_data['known_nodes']]
        self.server_key_list = [node['server_key'] for node in self.multinode_data['known_nodes']]
        self.node_name = self.multinode_data['this_node']
        for index, node in enumerate(self.multinode_data['known_nodes']):
            if node['name'] == self.node_name:
                self.node_index = index
                break
        else:
            raise ValueError('"this_node:" entry on the multinode_data json file is invalid')

        for node_name in self.node_name_list:
            # initialize all nodes data
            if node_name not in self.multinode_status:
                # Initialize new nodes. The node could already be present if
                # getMultinodeData is being called later.
                self.multinode_status[node_name] = dict()
                self.multinode_status[node_name]['health'] = -10  # initialized; never online/offline
                self.multinode_status[node_name]['last_sync_time'] = datetime(1991, 1, 1)
                self.multinode_status[node_name]['last_online_time'] = None
                self.multinode_status[node_name]['last_offline_time'] = None
                self.multinode_status[node_name]['last_scanned_time'] = None

    def configure_authenticator(self):
        self.auth.allow()
        # Tell the authenticator to use the certificates in a directory
        self.auth.configure_curve(domain='*', location=self.public_keys_dir)

    @Core.receiver('onsetup')
    def onsetup(self, sender, **kwargs):
        print "Setup"
        self.getMultinodeData()
        base_dir = settings.PROJECT_DIR + "/Agents/MultiNodeAgent/"
        public_keys_dir = os.path.abspath(os.path.join(base_dir, 'public_keys'))
        secret_keys_dir = os.path.abspath(os.path.join(base_dir, 'private_keys'))
        self.secret_keys_dir = secret_keys_dir
        self.public_keys_dir = public_keys_dir

        if not (os.path.exists(public_keys_dir) and os.path.exists(secret_keys_dir)):
            logging.critical("Certificates are missing - run generate_certificates.py script first")
            sys.exit(1)

        ctx = zmq.Context.instance()
        self.ctx = ctx

        # Start an authenticator for this context.
        self.auth = ThreadAuthenticator(ctx)
        self.auth.start()
        self.configure_authenticator()

        server = ctx.socket(zmq.PUB)
        server_secret_key_filename = self.multinode_data['known_nodes'][self.node_index]['server_secret_key']
        server_secret_file = os.path.join(secret_keys_dir, server_secret_key_filename)
        server_public, server_secret = zmq.auth.load_certificate(server_secret_file)
        server.curve_secretkey = server_secret
        server.curve_publickey = server_public
        server.curve_server = True  # must come before bind
        server.bind(self.multinode_data['known_nodes'][self.node_index]['address'])
        self.server = server

    def check_if_parent(self):
        if self.node_name == self.node_name_list[0]:  # the first entry is always the parent
            self.is_parent = True
            self.node_index = 0
            print "I am the boss now, " + self.node_name
            # start the web-server
            subprocess.check_output(settings.PROJECT_DIR + "/start_webserver.sh " +
                                    settings.PROJECT_DIR, shell=True)
            message = dict()
            message[STATUS_CHANGE.AGENT_ID] = 'devicediscoveryagent'
            message[STATUS_CHANGE.NODE] = str(self.node_index)
            message[STATUS_CHANGE.AGENT_STATUS] = 'start'
            message[STATUS_CHANGE.NODE_ASSIGNMENT_TYPE] = ZONE_ASSIGNMENT_TYPES.PERMANENT
            self.bemoss_publish('status_change', 'networkagent', [message])
            self.updateParent(self.node_name)
            print "discoveryagent started"

    def disperseMessage(self, topic, header, message):
        for node_name in self.node_name_list:
            if node_name == self.node_name:
                continue
            self.server.send(jsonify(node_name + '/republish/' + topic, message))

    def republishToParent(self, topic, header, message):
        if self.is_parent:
            return  # if I am the parent, the message is already published
        for node_name in self.node_name_list:
            if self.multinode_status[node_name]['health'] == 2:  # health == 2 is the parent node
                self.server.send(jsonify(node_name + '/republish/' + topic, message))

    @Core.periodic(20)
    def send_heartbeat(self):
        # self.vip.pubsub.publish('pubsub', 'listener', None, {'message': 'Hello Listener'})
        # print 'publishing'
        print "Sending heartbeat"
        last_sync_string = self.last_sync_with_parent.strftime('%B %d, %Y, %H:%M:%S')
        self.server.send(jsonify('heartbeat/' + self.node_name + '/' +
                                 str(self.is_parent) + '/' + last_sync_string, ""))

    def extract_ip(self, addr):
        return re.search(r'([0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3})', addr).groups()[0]

    def getNodeId(self, node_name):
        for index, node in enumerate(self.multinode_data['known_nodes']):
            if node['name'] == node_name:
                node_index = index
                break
        else:
            raise ValueError('the node name: ' + node_name + ' is not found in multinode data')
        return node_index

    def getNodeName(self, node_id):
        return self.multinode_data['known_nodes'][node_id]['name']

    def handle_offline_nodes(self, node_name_list):
        if self.is_parent:
            # start all the agents belonging to that node on this node
            command_group = []
            for node_name in node_name_list:
                node_id = self.getNodeId(node_name)
                # Put the offline event into the cassandra events log table,
                # and also create a notification.
                self.offline_variables['date_id'] = str(datetime.now().date())
                self.offline_variables['time'] = datetime.utcnow()
                self.offline_variables['agent_id'] = node_name
                self.offline_variables['event'] = 'node-offline'
                self.offline_variables['reason'] = 'communication-error'
                self.offline_variables['related_to'] = None
                self.offline_variables['event_id'] = uuid.uuid4()
                self.offline_variables['logged_time'] = datetime.utcnow()
                self.TSDCustomInsert(all_vars=self.offline_variables,
                                     log_vars=self.offline_log_variables,
                                     tablename=self.offline_table)
                time = date_converter.UTCToLocal(datetime.utcnow())
                message = str(node_name) + ': node-offline. Reason: possibly communication-error'
                self.curcon.execute("select id from possible_events where event_name=%s",
                                    ('node-offline',))
                event_id = self.curcon.fetchone()[0]
                self.curcon.execute(
                    "insert into notification (dt_triggered, seen, event_type_id, message) VALUES (%s, %s, %s, %s)",
                    (time, False, event_id, message))
                self.curcon.commit()
                # get a list of agents that were supposedly running on that offline node
                self.curcon.execute("SELECT agent_id FROM " + node_devices_table +
                                    " WHERE assigned_node_id=%s", (node_id,))
                if self.curcon.rowcount:
                    agent_ids = self.curcon.fetchall()
                    for agent_id in agent_ids:
                        message = dict()
                        message[STATUS_CHANGE.AGENT_ID] = agent_id[0]
                        message[STATUS_CHANGE.NODE] = str(self.node_index)
                        message[STATUS_CHANGE.AGENT_STATUS] = 'start'
                        message[STATUS_CHANGE.NODE_ASSIGNMENT_TYPE] = ZONE_ASSIGNMENT_TYPES.TEMPORARY
                        command_group += [message]
            print "moving agents from offline node to parent: " + str(node_name_list)
            print command_group
            if command_group:
                self.bemoss_publish('status_change', 'networkagent', command_group)

    def handle_online_nodes(self, node_name_list):
        if self.is_parent:
            # start all the agents belonging to those nodes back on them
            command_group = []
            for node_name in node_name_list:
                node_id = self.getNodeId(node_name)
                # Put the online event into the cassandra events log table,
                # and also create a notification.
                self.offline_variables['date_id'] = str(datetime.now().date())
                self.offline_variables['time'] = datetime.utcnow()
                self.offline_variables['agent_id'] = node_name
                self.offline_variables['event'] = 'node-online'
                self.offline_variables['reason'] = 'communication-restored'
                self.offline_variables['related_to'] = None
                self.offline_variables['event_id'] = uuid.uuid4()
                self.offline_variables['logged_time'] = datetime.utcnow()
                self.TSDCustomInsert(all_vars=self.offline_variables,
                                     log_vars=self.offline_log_variables,
                                     tablename=self.offline_table)
                time = date_converter.UTCToLocal(datetime.utcnow())
                message = str(node_name) + ': node-online. Reason: possibly communication-restored'
                self.curcon.execute("select id from possible_events where event_name=%s",
                                    ('node-online',))
                event_id = self.curcon.fetchone()[0]
                self.curcon.execute(
                    "insert into notification (dt_triggered, seen, event_type_id, message) VALUES (%s, %s, %s, %s)",
                    (time, False, event_id, message))
                self.curcon.commit()
                # get a list of agents that were supposed to be running on that online node
                self.curcon.execute("SELECT agent_id FROM " + node_devices_table +
                                    " WHERE assigned_node_id=%s", (node_id,))
                if self.curcon.rowcount:
                    agent_ids = self.curcon.fetchall()
                    for agent_id in agent_ids:
                        message = dict()
                        message[STATUS_CHANGE.AGENT_ID] = agent_id[0]
                        message[STATUS_CHANGE.NODE_ASSIGNMENT_TYPE] = ZONE_ASSIGNMENT_TYPES.PERMANENT
                        message[STATUS_CHANGE.NODE] = str(self.node_index)
                        message[STATUS_CHANGE.AGENT_STATUS] = 'stop'  # stop on this node
                        command_group += [message]
                        message = dict(message)  # create another copy
                        message[STATUS_CHANGE.NODE] = str(node_id)
                        message[STATUS_CHANGE.AGENT_STATUS] = 'start'  # start on the target node
                        command_group += [message]
            print "Moving agents back to the online node: " + str(node_name_list)
            print command_group
            if command_group:
                self.bemoss_publish('status_change', 'networkagent', command_group)

    def updateParent(self, parent_node_name):
        parent_ip = self.extract_ip(self.nodelist_dict[parent_node_name]['address'])
        write_new = False
        if not os.path.isfile(settings.MULTINODE_PARENT_IP_FILE):  # the parent file doesn't exist
            write_new = True
        else:
            with open(settings.MULTINODE_PARENT_IP_FILE, 'r') as f:
                read_ip = f.read()
            if read_ip != parent_ip:
                write_new = True
        if write_new:
            with open(settings.MULTINODE_PARENT_IP_FILE, 'w') as f:
                f.write(parent_ip)
            if self.curcon:
                self.curcon.close()  # close the old connection
            self.curcon = db_connection()  # start a new connection using the new parent_ip
            self.vip.pubsub.publish('pubsub', 'from/multinodeagent/update_parent')

    @Core.periodic(60)
    def check_health(self):
        for node_name, node in self.multinode_status.items():
            if node['health'] > 0:
                # Initialize all online nodes to 0. If they are really online,
                # they should change it back to 1 or 2 (parent) within 30
                # seconds through the heartbeat.
                node['health'] = 0
        gevent.sleep(30)
        parent_node_name = None  # initialize parent node
        online_node_exists = False
        for node_name, node in self.multinode_status.items():
            node['last_scanned_time'] = datetime.now()
            if node['health'] == 0:
                node['health'] = -1
                node['last_offline_time'] = datetime.now()
                self.recently_offline_node_list += [node_name]
            elif node['health'] == -1:  # offline since long
                pass
            elif node['health'] == -10:
                # The node was initialized to -10 and never came online. Treat
                # it as recently going offline for this iteration so that the
                # agents that were supposed to be running there can be migrated.
                node['health'] = -1
                self.recently_offline_node_list += [node_name]
            elif node['health'] == 2:  # there is some parent node present
                parent_node_name = node_name
            if node['health'] > 0:
                online_node_exists = True
        # At least one node (itself) should be online; if not, something is wrong
        assert online_node_exists, "At least one node (current node) must be online"
        if parent_node_name:  # a parent node exists
            self.updateParent(parent_node_name)
        for node in self.multinode_data['known_nodes']:
            print node['name'] + ': ' + str(self.multinode_status[node['name']]['health'])

        if self.is_parent:
            # if this is a parent node, update the node_info table
            if self.curcon is None:  # if no database connection exists, make a connection
                self.curcon = db_connection()
            tbl_node_info = settings.DATABASES['default']['TABLE_node_info']
            self.curcon.execute('select node_id from ' + tbl_node_info)
            to_be_deleted_node_ids = self.curcon.fetchall()
            for index, node in enumerate(self.multinode_data['known_nodes']):
                if (index,) in to_be_deleted_node_ids:
                    to_be_deleted_node_ids.remove((index,))  # don't remove this current node
                result = self.curcon.execute('select * from ' + tbl_node_info +
                                             ' where node_id=%s', (index,))
                node_type = 'parent' if self.multinode_status[node['name']]['health'] == 2 else "child"
                node_status = "ONLINE" if self.multinode_status[node['name']]['health'] > 0 else "OFFLINE"
                ip_address = self.extract_ip(node['address'])
                last_scanned_time = self.multinode_status[node['name']]['last_online_time']
                last_offline_time = self.multinode_status[node['name']]['last_offline_time']
                last_sync_time = self.multinode_status[node['name']]['last_sync_time']
                var_list = "(node_id,node_name,node_type,node_status,ip_address,last_scanned_time,last_offline_time,last_sync_time)"
                value_placeholder_list = "(%s,%s,%s,%s,%s,%s,%s,%s)"
                actual_values_list = (index, node['name'], node_type, node_status, ip_address,
                                      last_scanned_time, last_offline_time, last_sync_time)
                if self.curcon.rowcount == 0:
                    self.curcon.execute("insert into " + tbl_node_info + " " + var_list +
                                        " VALUES" + value_placeholder_list, actual_values_list)
                else:
                    self.curcon.execute("update " + tbl_node_info + " SET " + var_list + " = " +
                                        value_placeholder_list + " where node_id = %s",
                                        actual_values_list + (index,))
                self.curcon.commit()
            for id in to_be_deleted_node_ids:
                # delete entries in user-profile for the old node
                self.curcon.execute('delete from accounts_userprofile_nodes where nodeinfo_id=%s', id)
                self.curcon.commit()
                # delete the old nodes
                self.curcon.execute('delete from ' + tbl_node_info + ' where node_id=%s', id)
                self.curcon.commit()

        if self.recently_online_node_list:
            # Online nodes should be handled first because the same node can
            # be on both recently_online_node_list and recently_offline_node_list
            # when it goes offline shortly after coming online.
            self.handle_online_nodes(self.recently_online_node_list)
            self.recently_online_node_list = []  # reset after handling
        if self.recently_offline_node_list:
            self.handle_offline_nodes(self.recently_offline_node_list)
            self.recently_offline_node_list = []  # reset after handling

    def connect_client(self, node):
        server_public_file = os.path.join(self.public_keys_dir, node['server_key'])
        server_public, _ = zmq.auth.load_certificate(server_public_file)
        # The client must know the server's public key to make a CURVE connection.
        self.client.curve_serverkey = server_public
        self.client.setsockopt(zmq.SUBSCRIBE, 'heartbeat/')
        self.client.setsockopt(zmq.SUBSCRIBE, self.node_name)
        self.client.connect(node['address'])

    def disconnect_client(self, node):
        self.client.disconnect(node['address'])

    @Core.receiver('onstart')
    def onstart(self, sender, **kwargs):
        self.check_if_parent()
        print "Starting to receive Heart-beat"
        self.vip.heartbeat.start_with_period(15)
        client = self.ctx.socket(zmq.SUB)
        # We need two certificates, one for the client and one for the server.
        # The client must know the server's public key to make a CURVE
        # connection.
        client_secret_key_filename = self.multinode_data['known_nodes'][self.node_index]['client_secret_key']
        client_secret_file = os.path.join(self.secret_keys_dir, client_secret_key_filename)
        client_public, client_secret = zmq.auth.load_certificate(client_secret_file)
        client.curve_secretkey = client_secret
        client.curve_publickey = client_public
        self.client = client
        for node in self.multinode_data['known_nodes']:
            self.connect_client(node)
        print "Starting to listen"
        try:
            while True:
                # read messages
                if client.poll(1000):
                    topic, msg = dejsonify(client.recv())
                    topic_list = topic.split('/')
                    if topic_list[0] == 'heartbeat':
                        node_name = topic_list[1]
                        is_parent = topic_list[2]
                        last_sync_with_parent = topic_list[3]
                        if self.multinode_status[node_name]['health'] < 0:
                            # the node health was < 0, meaning offline
                            print node_name + " is back online"
                            self.recently_online_node_list += [node_name]
                        if is_parent.lower() in ['false', '0']:
                            self.multinode_status[node_name]['health'] = 1
                        elif is_parent.lower() in ['true', '1']:
                            self.multinode_status[node_name]['health'] = 2
                            self.parent_node = node_name
                        else:
                            raise ValueError('Invalid is_parent string in heart-beat message')
                        self.multinode_status[node_name]['last_online_time'] = datetime.now()
                    if topic_list[0] == self.node_name:
                        # message addressed to this node
                        if topic_list[1] == 'republish':
                            new_topic = '/'.join(topic_list[2:] +
                                                 ['repub-by-' + self.node_name, 'republished'])
                            self.vip.pubsub.publish('pubsub', new_topic, None, msg)
                        print self.node_name + ": " + topic, str(msg)
                else:
                    gevent.sleep(2)
        except Exception as er:
            print "error"
            print er
            # stop auth thread
            self.auth.stop()

    @PubSub.subscribe('pubsub', 'to/multinodeagent/')
    def updateMultinodeData(self, peer, sender, bus, topic, headers, message):
        print "Updating Multinode data"
        topic_list = topic.split('/')
        self.configure_authenticator()
        # to/multinodeagent/from/<doesn't matter>/update_multinode_data
        if topic_list[4] == 'update_multinode_data':
            old_multinode_data = self.multinode_data
            self.getMultinodeData()
            for node in self.multinode_data['known_nodes']:
                if node not in old_multinode_data['known_nodes']:
                    print "New node has been added to the cluster: " + node['name']
                    print "We will connect to it"
                    self.connect_client(node)
            for node in old_multinode_data['known_nodes']:
                if node not in self.multinode_data['known_nodes']:
                    print "Node has been removed from the cluster: " + node['name']
                    print "We will disconnect from it"
                    self.disconnect_client(node)
                    # TODO: remove it from the node_info table
            print "yay! got it"

    @PubSub.subscribe('pubsub', 'to/')
    def relayToMessage(self, peer, sender, bus, topic, headers, message):
        print topic
        topic_list = topic.split('/')
        # to/<some_agent_or_ui>/topic/from/<some_agent_or_ui>
        to_index = topic_list.index('to') + 1
        if 'from' in topic_list:
            from_index = topic_list.index('from') + 1
            from_entity = topic_list[from_index]
        to_entity = topic_list[to_index]
        last_field = topic_list[-1]
        if last_field == 'republished':
            # it is already a republished message, no need to republish
            return
        if to_entity in settings.SYSTEM_AGENTS:
            self.disperseMessage(topic, headers, message)  # republish to all nodes
        elif to_entity in settings.PARENT_NODE_SYSTEM_AGENTS:
            if not self.is_parent:
                self.republishToParent(topic, headers, message)
        else:
            self.curcon.execute("SELECT current_node_id FROM " + node_devices_table +
                                " WHERE agent_id=%s", (to_entity,))
            if self.curcon.rowcount:
                node_id = self.curcon.fetchone()[0]
                if node_id != self.node_index:
                    self.server.send(jsonify(self.getNodeName(node_id) +
                                             '/republish/' + topic, message))

    @PubSub.subscribe('pubsub', 'from/')
    def relayFromMessage(self, peer, sender, bus, topic, headers, message):
        topic_list = topic.split('/')
        # from/<some_agent_or_ui>/topic
        from_entity = topic_list[1]
        last_field = topic_list[-1]
        if last_field == 'republished':
            # it is a republished message, no need to publish
            return
        self.disperseMessage(topic, headers, message)  # republish to all nodes

    @PubSub.subscribe('pubsub', '')
    def on_match(self, peer, sender, bus, topic, headers, message):
        '''Use match_all to receive all messages and print them out.'''
        if sender == 'pubsub.compat':
            message = compat.unpack_legacy_message(headers, message)
class Command(LAVADaemonCommand): """ worker_host is the hostname of the worker this field is set by the admin and could therefore be empty in a misconfigured instance. """ logger = None help = "LAVA dispatcher master" default_logfile = "/var/log/lava-server/lava-master.log" def __init__(self, *args, **options): super().__init__(*args, **options) self.auth = None self.controler = None self.event_socket = None self.poller = None self.pipe_r = None self.inotify_fd = None # List of logs # List of known dispatchers. At startup do not load this from the # database. This will help to know if the slave as restarted or not. self.dispatchers = { "lava-logs": SlaveDispatcher("lava-logs", online=False) } self.events = {"canceling": set(), "available_dt": set()} def add_arguments(self, parser): super().add_arguments(parser) net = parser.add_argument_group("network") net.add_argument( '--master-socket', default='tcp://*:5556', help="Socket for master-slave communication. Default: tcp://*:5556" ) net.add_argument('--event-url', default="tcp://localhost:5500", help="URL of the publisher") net.add_argument('--ipv6', default=False, action='store_true', help="Enable IPv6 on the listening sockets") net.add_argument('--encrypt', default=False, action='store_true', help="Encrypt messages") net.add_argument( '--master-cert', default='/etc/lava-dispatcher/certificates.d/master.key_secret', help="Certificate for the master socket") net.add_argument('--slaves-certs', default='/etc/lava-dispatcher/certificates.d', help="Directory for slaves certificates") def send_status(self, hostname): """ The master crashed, send a STATUS message to get the current state of jobs """ jobs = TestJob.objects.filter( actual_device__worker_host__hostname=hostname, state=TestJob.STATE_RUNNING) for job in jobs: self.logger.info("[%d] STATUS => %s (%s)", job.id, hostname, job.actual_device.hostname) send_multipart_u(self.controler, [hostname, 'STATUS', str(job.id)]) def dispatcher_alive(self, hostname): if hostname not in self.dispatchers: # The server crashed: send a STATUS message self.logger.warning("Unknown dispatcher <%s> (server crashed)", hostname) self.dispatchers[hostname] = SlaveDispatcher(hostname) self.send_status(hostname) # Mark the dispatcher as alive self.dispatchers[hostname].alive() def controler_socket(self): try: # We need here to use the zmq.NOBLOCK flag, otherwise we could block # the whole main loop where this function is called. 
msg = self.controler.recv_multipart(zmq.NOBLOCK) except zmq.error.Again: return False # This is way to verbose for production and should only be activated # by (and for) developers # self.logger.debug("[CC] Receiving: %s", msg) # 1: the hostname (see ZMQ documentation) hostname = u(msg[0]) # 2: the action action = u(msg[1]) # Check that lava-logs only send PINGs if hostname == "lava-logs" and action != "PING": self.logger.error("%s => %s Invalid action from log daemon", hostname, action) return True # Handle the actions if action == 'HELLO' or action == 'HELLO_RETRY': self._handle_hello(hostname, action, msg) elif action == 'PING': self._handle_ping(hostname, action, msg) elif action == 'END': self._handle_end(hostname, action, msg) elif action == 'START_OK': self._handle_start_ok(hostname, action, msg) else: self.logger.error("<%s> sent unknown action=%s, args=(%s)", hostname, action, msg[1:]) return True def read_event_socket(self): try: msg = self.event_socket.recv_multipart(zmq.NOBLOCK) except zmq.error.Again: return False try: (topic, _, dt, username, data) = (u(m) for m in msg) data = simplejson.loads(data) except ValueError: self.logger.error("Invalid event: %s", msg) return True if topic.endswith(".testjob"): if data["state"] == "Canceling": self.events["canceling"].add(int(data["job"])) elif data["state"] == "Submitted": if "device_type" in data: self.events["available_dt"].add(data["device_type"]) elif topic.endswith(".device"): if data["state"] == "Idle" and data["health"] in [ "Good", "Unknown", "Looping" ]: self.events["available_dt"].add(data["device_type"]) return True def _handle_end(self, hostname, action, msg): # pylint: disable=unused-argument try: job_id = int(msg[2]) error_msg = msg[3] compressed_description = msg[4] except (IndexError, ValueError): self.logger.error("Invalid message from <%s> '%s'", hostname, msg) return try: job = TestJob.objects.get(id=job_id) except TestJob.DoesNotExist: self.logger.error("[%d] Unknown job", job_id) # ACK even if the job is unknown to let the dispatcher # forget about it send_multipart_u(self.controler, [hostname, 'END_OK', str(job_id)]) return filename = os.path.join(job.output_dir, 'description.yaml') # If description.yaml already exists: a END was already received if os.path.exists(filename): self.logger.info("[%d] %s => END (duplicated), skipping", job_id, hostname) else: if compressed_description: self.logger.info("[%d] %s => END", job_id, hostname) else: self.logger.info( "[%d] %s => END (lava-run crashed, mark job as INCOMPLETE)", job_id, hostname) with transaction.atomic(): # TODO: find a way to lock actual_device job = TestJob.objects.select_for_update() \ .get(id=job_id) job.go_state_finished(TestJob.HEALTH_INCOMPLETE) if error_msg: self.logger.error("[%d] Error: %s", job_id, error_msg) job.failure_comment = error_msg job.save() # Create description.yaml even if it's empty # Allows to know when END messages are duplicated try: # Create the directory if it was not already created mkdir(os.path.dirname(filename)) # TODO: check that compressed_description is not "" description = lzma.decompress(compressed_description) with open(filename, 'w') as f_description: f_description.write(description.decode("utf-8")) if description: parse_job_description(job) except (OSError, lzma.LZMAError) as exc: self.logger.error("[%d] Unable to dump 'description.yaml'", job_id) self.logger.exception("[%d] %s", job_id, exc) # ACK the job and mark the dispatcher as alive send_multipart_u(self.controler, [hostname, 'END_OK', str(job_id)]) 
        self.dispatcher_alive(hostname)

    def _handle_hello(self, hostname, action, msg):
        # Check the protocol version
        try:
            slave_version = int(msg[2])
        except (IndexError, ValueError):
            self.logger.error("Invalid message from <%s> '%s'", hostname, msg)
            return

        self.logger.info("%s => %s", hostname, action)
        if slave_version != PROTOCOL_VERSION:
            self.logger.error("<%s> using protocol v%d while master is using v%d",
                              hostname, slave_version, PROTOCOL_VERSION)
            return

        send_multipart_u(self.controler, [hostname, 'HELLO_OK'])
        # If the dispatcher is known and sent a HELLO, it means that the
        # slave has restarted
        if hostname in self.dispatchers:
            if action == 'HELLO':
                self.logger.warning("Dispatcher <%s> has RESTARTED", hostname)
            else:
                # Assume the HELLO command was received and the action
                # succeeded.
                self.logger.warning("Dispatcher <%s> was not confirmed", hostname)
        else:
            # No dispatcher yet: treat HELLO and HELLO_RETRY as a normal
            # HELLO message.
            self.logger.warning("New dispatcher <%s>", hostname)
            self.dispatchers[hostname] = SlaveDispatcher(hostname)

        # Mark the dispatcher as alive
        self.dispatcher_alive(hostname)

    def _handle_ping(self, hostname, action, msg):  # pylint: disable=unused-argument
        self.logger.debug("%s => PING(%d)", hostname, PING_INTERVAL)
        # Send back a PONG
        send_multipart_u(self.controler, [hostname, 'PONG',
                                          str(PING_INTERVAL)])
        self.dispatcher_alive(hostname)

    def _handle_start_ok(self, hostname, action, msg):  # pylint: disable=unused-argument
        try:
            job_id = int(msg[2])
        except (IndexError, ValueError):
            self.logger.error("Invalid message from <%s> '%s'", hostname, msg)
            return
        self.logger.info("[%d] %s => START_OK", job_id, hostname)
        try:
            with transaction.atomic():
                # TODO: find a way to lock actual_device
                job = TestJob.objects.select_for_update().get(id=job_id)
                job.go_state_running()
                job.save()
        except TestJob.DoesNotExist:
            self.logger.error("[%d] Unknown job", job_id)
        else:
            self.dispatcher_alive(hostname)

    def export_definition(self, job):  # pylint: disable=no-self-use
        job_def = yaml.safe_load(job.definition)
        job_def['compatibility'] = job.pipeline_compatibility

        # no need for the dispatcher to retain comments
        return yaml.dump(job_def)

    def save_job_config(self, job, device_cfg, env_str, env_dut_str,
                        dispatcher_cfg):
        output_dir = job.output_dir
        mkdir(output_dir)
        with open(os.path.join(output_dir, "job.yaml"), "w") as f_out:
            f_out.write(self.export_definition(job))
        with open(os.path.join(output_dir, "device.yaml"), "w") as f_out:
            yaml.dump(device_cfg, f_out)
        if env_str:
            with open(os.path.join(output_dir, "env.yaml"), "w") as f_out:
                f_out.write(env_str)
        if env_dut_str:
            with open(os.path.join(output_dir, "env.dut.yaml"), "w") as f_out:
                f_out.write(env_dut_str)
        if dispatcher_cfg:
            with open(os.path.join(output_dir, "dispatcher.yaml"), "w") as f_out:
                f_out.write(dispatcher_cfg)

    def start_job(self, job):
        # Load the job definition to get the variables for template
        # rendering
        job_def = yaml.safe_load(job.definition)
        job_ctx = job_def.get('context', {})

        device = job.actual_device
        worker = device.worker_host

        # TODO: check that device_cfg is not None!
        device_cfg = device.load_configuration(job_ctx)

        # Try to load the dispatcher specific files and then fall back to
        # the default configuration files.
        env_str = load_optional_yaml_file(
            os.path.join(DISPATCHERS_PATH, worker.hostname, "env.yaml"),
            ENV_PATH)
        env_dut_str = load_optional_yaml_file(
            os.path.join(DISPATCHERS_PATH, worker.hostname, "env.dut.yaml"),
            ENV_DUT_PATH)
        dispatcher_cfg = load_optional_yaml_file(
            os.path.join(DISPATCHERS_PATH, worker.hostname, "dispatcher.yaml"),
            os.path.join(DISPATCHERS_PATH, "%s.yaml" % worker.hostname))

        self.save_job_config(job, device_cfg, env_str, env_dut_str,
                             dispatcher_cfg)

        self.logger.info("[%d] START => %s (%s)", job.id,
                         worker.hostname, device.hostname)
        send_multipart_u(self.controler,
                         [worker.hostname, 'START', str(job.id),
                          self.export_definition(job),
                          yaml.dump(device_cfg), dispatcher_cfg,
                          env_str, env_dut_str])

        # For multinode jobs, start the dynamic connections
        parent = job
        for sub_job in job.sub_jobs_list:
            if sub_job == parent or not sub_job.dynamic_connection:
                continue

            # inherit only enough configuration for dynamic_connection operation
            self.logger.info("[%d] Trimming dynamic connection device configuration.",
                             sub_job.id)
            min_device_cfg = parent.actual_device.minimise_configuration(device_cfg)

            self.save_job_config(sub_job, min_device_cfg, env_str,
                                 env_dut_str, dispatcher_cfg)
            self.logger.info("[%d] START => %s (connection)",
                             sub_job.id, worker.hostname)
            send_multipart_u(self.controler,
                             [worker.hostname, 'START', str(sub_job.id),
                              self.export_definition(sub_job),
                              yaml.dump(min_device_cfg),
                              dispatcher_cfg, env_str, env_dut_str])

    def start_jobs(self, jobs=None):
        """
        Loop on all scheduled jobs and send the START message to the slave.
        """
        # make the request atomic
        query = TestJob.objects.select_for_update()
        # Only select test jobs that are ready
        query = query.filter(state=TestJob.STATE_SCHEDULED)
        # Only start jobs on online workers
        query = query.filter(actual_device__worker_host__state=Worker.STATE_ONLINE)
        # Exclude test jobs without a device: they are special test jobs
        # like dynamic connections.
        query = query.exclude(actual_device=None)
        # Allow for partial scheduling
        if jobs is not None:
            query = query.filter(id__in=jobs)

        # Loop on all jobs
        for job in query:
            msg = None
            try:
                self.start_job(job)
            except jinja2.TemplateNotFound as exc:
                self.logger.error("[%d] Template not found: '%s'",
                                  job.id, exc.message)
                msg = "Template not found: '%s'" % exc.message
            except jinja2.TemplateSyntaxError as exc:
                self.logger.error("[%d] Template syntax error in '%s', line %d: %s",
                                  job.id, exc.name, exc.lineno, exc.message)
                msg = "Template syntax error in '%s', line %d: %s" % (
                    exc.name, exc.lineno, exc.message)
            except OSError as exc:
                self.logger.error("[%d] Unable to read '%s': %s",
                                  job.id, exc.filename, exc.strerror)
                msg = "Cannot open '%s': %s" % (exc.filename, exc.strerror)
            except yaml.YAMLError as exc:
                self.logger.error("[%d] Unable to parse job definition: %s",
                                  job.id, exc)
                msg = "Cannot parse job definition: %s" % exc

            if msg:
                # Add the error as a lava.job result
                metadata = {"case": "job",
                            "definition": "lava",
                            "error_type": "Infrastructure",
                            "error_msg": msg,
                            "result": "fail"}
                suite, _ = TestSuite.objects.get_or_create(name="lava", job=job)
                TestCase.objects.create(name="job", suite=suite,
                                        result=TestCase.RESULT_FAIL,
                                        metadata=yaml.dump(metadata))
                job.go_state_finished(TestJob.HEALTH_INCOMPLETE, True)
                job.save()

    def cancel_jobs(self, partial=False):
        # make the request atomic
        query = TestJob.objects.select_for_update()
        # Only select test jobs that are canceling
        query = query.filter(state=TestJob.STATE_CANCELING)
        # Only cancel jobs on online workers
        query = query.filter(actual_device__worker_host__state=Worker.STATE_ONLINE)

        # Allow for partial canceling
        if partial:
            query = query.filter(id__in=list(self.events["canceling"]))

        # Loop on all jobs
        for job in query:
            worker = (job.lookup_worker if job.dynamic_connection
                      else job.actual_device.worker_host)
            self.logger.info("[%d] CANCEL => %s", job.id, worker.hostname)
            send_multipart_u(self.controler,
                             [worker.hostname, 'CANCEL', str(job.id)])

    def handle(self, *args, **options):
        # Initialize logging.
        self.setup_logging("lava-master", options["level"],
                           options["log_file"], FORMAT)

        self.logger.info("[INIT] Dropping privileges")
        if not self.drop_privileges(options['user'], options['group']):
            self.logger.error("[INIT] Unable to drop privileges")
            return

        filename = os.path.join(settings.MEDIA_ROOT, 'lava-master-config.yaml')
        self.logger.debug("[INIT] Dumping config to %s", filename)
        with open(filename, 'w') as output:
            yaml.dump(options, output)

        self.logger.info("[INIT] Marking all workers as offline")
        with transaction.atomic():
            for worker in Worker.objects.select_for_update().all():
                worker.go_state_offline()
                worker.save()

        # Create the sockets
        context = zmq.Context()
        self.controler = context.socket(zmq.ROUTER)
        self.event_socket = context.socket(zmq.SUB)

        if options['ipv6']:
            self.logger.info("[INIT] Enabling IPv6")
            self.controler.setsockopt(zmq.IPV6, 1)
            self.event_socket.setsockopt(zmq.IPV6, 1)

        if options['encrypt']:
            self.logger.info("[INIT] Starting encryption")
            try:
                self.auth = ThreadAuthenticator(context)
                self.auth.start()
                self.logger.debug("[INIT] Opening master certificate: %s",
                                  options['master_cert'])
                master_public, master_secret = zmq.auth.load_certificate(options['master_cert'])
                self.logger.debug("[INIT] Using slaves certificates from: %s",
                                  options['slaves_certs'])
                self.auth.configure_curve(domain='*',
                                          location=options['slaves_certs'])
            except OSError as err:
                self.logger.error(err)
                self.auth.stop()
                return
            self.controler.curve_publickey = master_public
            self.controler.curve_secretkey = master_secret
            self.controler.curve_server = True

            self.logger.debug("[INIT] Watching %s", options["slaves_certs"])
            self.inotify_fd = watch_directory(options["slaves_certs"])
            if self.inotify_fd is None:
                self.logger.error("[INIT] Unable to start inotify")

        self.controler.setsockopt(zmq.IDENTITY, b"master")
        # From http://api.zeromq.org/4-2:zmq-setsockopt#toc42
        # "If two clients use the same identity when connecting to a ROUTER
        # [...] the ROUTER socket shall hand-over the connection to the new
        # client and disconnect the existing one."
        self.controler.setsockopt(zmq.ROUTER_HANDOVER, 1)
        self.controler.bind(options['master_socket'])

        self.event_socket.setsockopt(zmq.SUBSCRIBE, b(settings.EVENT_TOPIC))
        self.event_socket.connect(options['event_url'])

        # Poll on the sockets. This allows having a nice timeout along with
        # polling.
        self.poller = zmq.Poller()
        self.poller.register(self.controler, zmq.POLLIN)
        self.poller.register(self.event_socket, zmq.POLLIN)
        if self.inotify_fd is not None:
            self.poller.register(os.fdopen(self.inotify_fd), zmq.POLLIN)

        # Translate signals into zmq messages
        (self.pipe_r, _) = self.setup_zmq_signal_handler()
        self.poller.register(self.pipe_r, zmq.POLLIN)

        self.logger.info("[INIT] LAVA master has started.")
        self.logger.info("[INIT] Using protocol version %d", PROTOCOL_VERSION)

        try:
            self.main_loop(options)
        except BaseException as exc:
            self.logger.error("[CLOSE] Unknown exception raised, leaving!")
            self.logger.exception(exc)
        finally:
            # Drop the controler socket: the protocol does handle lost
            # messages
            self.logger.info("[CLOSE] Closing the controler socket and dropping messages")
            self.controler.close(linger=0)
            self.event_socket.close(linger=0)
            if options['encrypt']:
                self.auth.stop()
            context.term()

    def main_loop(self, options):
        last_schedule = last_dispatcher_check = time.time()

        while True:
            try:
                try:
                    # Compute the timeout
                    now = time.time()
                    timeout = min(SCHEDULE_INTERVAL - (now - last_schedule),
                                  PING_INTERVAL - (now - last_dispatcher_check))
                    # If some actions are remaining, decrease the timeout
                    if any(self.events.values()):
                        timeout = min(timeout, 2)
                    # Wait at least for 1ms
                    timeout = max(timeout * 1000, 1)

                    # Wait for data or a timeout
                    sockets = dict(self.poller.poll(timeout))
                except zmq.error.ZMQError:
                    continue

                if sockets.get(self.pipe_r) == zmq.POLLIN:
                    self.logger.info("[POLL] Received a signal, leaving")
                    break

                # Command socket
                if sockets.get(self.controler) == zmq.POLLIN:
                    while self.controler_socket():  # Unqueue all pending messages
                        pass

                # Events socket
                if sockets.get(self.event_socket) == zmq.POLLIN:
                    while self.read_event_socket():  # Unqueue all pending messages
                        pass
                    # Wait for the next iteration to handle the event.
                    # In fact, the code that generated the event (lava-logs or
                    # lava-server-gunicorn) needs some time to commit the
                    # database transaction.
                    # If we are too fast, the database object won't be
                    # available (or in the right state) yet.
                    continue

                # Inotify socket
                if sockets.get(self.inotify_fd) == zmq.POLLIN:
                    os.read(self.inotify_fd, 4096)
                    self.logger.debug("[AUTH] Reloading certificates from %s",
                                      options['slaves_certs'])
                    self.auth.configure_curve(domain='*',
                                              location=options['slaves_certs'])

                # Check dispatchers status
                now = time.time()
                if now - last_dispatcher_check > PING_INTERVAL:
                    for hostname, dispatcher in self.dispatchers.items():
                        if dispatcher.online and now - dispatcher.last_msg > DISPATCHER_TIMEOUT:
                            if hostname == "lava-logs":
                                self.logger.error("[STATE] lava-logs goes OFFLINE")
                            else:
                                self.logger.error("[STATE] Dispatcher <%s> goes OFFLINE",
                                                  hostname)
                            self.dispatchers[hostname].go_offline()
                    last_dispatcher_check = now

                # Limit accesses to the database.
                # This will also limit the rate of CANCEL and START messages
                if time.time() - last_schedule > SCHEDULE_INTERVAL:
                    if self.dispatchers["lava-logs"].online:
                        schedule(self.logger)

                        # Dispatch scheduled jobs
                        with transaction.atomic():
                            self.start_jobs()
                    else:
                        self.logger.warning("lava-logs is offline: can't schedule jobs")

                    # Handle canceling jobs
                    with transaction.atomic():
                        self.cancel_jobs()

                    # Do not count the time taken to schedule jobs
                    last_schedule = time.time()
                else:
                    # Cancel the jobs and remove them from the set
                    if self.events["canceling"]:
                        with transaction.atomic():
                            self.cancel_jobs(partial=True)
                        self.events["canceling"] = set()
                    # Schedule for available device-types
                    if self.events["available_dt"]:
                        jobs = schedule(self.logger, self.events["available_dt"])
                        self.events["available_dt"] = set()
                        # Dispatch scheduled jobs
                        with transaction.atomic():
                            self.start_jobs(jobs)
            except (OperationalError, InterfaceError):
                self.logger.info("[RESET] database connection reset.")
                # Closing the database connection will force Django to reopen
                # the connection
                connection.close()
                time.sleep(2)
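
# A minimal sketch (not part of LAVA) of the slave side of the encrypted
# master ROUTER socket above. It assumes the default certificate paths from
# add_arguments() plus a hypothetical slave certificate name; the HELLO
# payload mirrors what _handle_hello() expects (action, then protocol
# version), but the version value here is only an illustrative placeholder.
import zmq
import zmq.auth


def connect_to_master(master_url="tcp://localhost:5556",
                      hostname=b"worker01",  # hypothetical worker hostname
                      slave_cert="/etc/lava-dispatcher/certificates.d/slave.key_secret",
                      master_cert="/etc/lava-dispatcher/certificates.d/master.key"):
    context = zmq.Context.instance()
    sock = context.socket(zmq.DEALER)
    # The master reads msg[0] as the hostname; on a ROUTER socket that frame
    # is the peer identity, so the identity must be the worker's hostname.
    sock.setsockopt(zmq.IDENTITY, hostname)
    # Client side of CURVE: our own key pair plus the server's public key.
    public, secret = zmq.auth.load_certificate(slave_cert)
    sock.curve_publickey = public
    sock.curve_secretkey = secret
    server_public, _ = zmq.auth.load_certificate(master_cert)
    sock.curve_serverkey = server_public
    sock.connect(master_url)
    # Register with the master; it replies HELLO_OK on success.
    sock.send_multipart([b"HELLO", b"3"])  # "3" stands in for PROTOCOL_VERSION
    return sock
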

class StupidNode:
    pubkey = privkey = None
    channel = ""  # subscription filter or something (I think)
    PORTS = 4  # as we add or remove ports, make sure this is the number of ports a StupidNode uses

    def __init__(self, endpoint="*", identity=None, keyring=DEFAULT_KEYRING):
        self.keyring = keyring
        self.endpoint = (endpoint if isinstance(endpoint, Endpoint)
                         else Endpoint(endpoint))
        self.endpoints = list()
        self.identity = identity or f"{gethostname()}-{self.endpoint.pub}"
        self.log = logging.getLogger(f"{self.identity}")

        self.log.debug("begin node setup / creating context")

        self.ctx = zmq.Context()
        self.cleartext_ctx = zmq.Context()

        self.start_auth()

        self.log.debug("creating sockets")

        self.pub = self.mk_socket(zmq.PUB)
        self.router = self.mk_socket(zmq.ROUTER)
        self.router.router_mandatory = 1  # one of the few opts that can be set after bind()
        self.rep = self.mk_socket(zmq.REP, enable_curve=False)

        self.sub = list()
        self.dealer = list()

        self.log.debug("binding sockets")

        self.bind(self.pub)
        self.bind(self.router)
        self.bind(self.rep, enable_curve=False)

        self.log.debug("registering polling")

        self.poller = zmq.Poller()
        self.poller.register(self.router, zmq.POLLIN)

        self.log.debug("configuring interrupt signal")
        signal.signal(signal.SIGINT, self.interrupt)

        self.log.debug("configuring WAI Reply Thread")

        self._who_are_you_thread = Thread(target=self.who_are_you_reply_machine)
        self._who_are_you_continue = True
        self._who_are_you_thread.start()

        self.route_queue = deque(list(), ROUTE_QUEUE_LEN)
        self.routes = dict()

        self.log.debug("node setup complete")

    def who_are_you_reply_machine(self):
        while self._who_are_you_continue:
            if self.rep.poll(200):
                self.log.debug("wai polled, trying to recv")
                msg = self.rep.recv()
                ttype = zmq_socket_type_name(self.rep)
                self.log.debug('received "%s" over %s socket', msg, ttype)
                msg = [self.identity.encode(), self.pubkey]
                self.log.debug('sending "%s" as reply over %s socket', msg, ttype)
                self.rep.send_multipart(msg)
        self.log.debug("wai thread seems finished, loop broken")

    def start_auth(self):
        self.log.debug("starting auth thread")
        self.auth = ThreadAuthenticator(self.ctx)
        self.auth.start()
        self.auth.allow("127.0.0.1")
        self.auth.configure_curve(domain="*", location=self.keyring)
        self.load_or_create_key()

    @property
    def key_basename(self):
        return scrub_identity_name_for_certfile(self.identity)

    @property
    def key_filename(self):
        return os.path.join(self.keyring, self.key_basename + ".key")

    @property
    def secret_key_filename(self):
        return self.key_filename + "_secret"

    def load_key(self):
        self.log.debug("loading node key-pair")
        self.pubkey, self.privkey = zmq.auth.load_certificate(self.secret_key_filename)

    def load_or_create_key(self):
        try:
            self.load_key()
        except IOError as e:
            self.log.debug("error loading key: %s", e)
            self.log.debug("creating node key-pair")
            os.makedirs(self.keyring, mode=0o0700, exist_ok=True)
            zmq.auth.create_certificates(self.keyring, self.key_basename)
            self.load_key()

    def preprocess_message(self, msg, msg_class=TaggedMessage):
        if not isinstance(msg, msg_class):
            if not isinstance(msg, (list, tuple)):
                msg = (msg,)
            msg = msg_class(*msg, name=self.identity)
        rmsg = repr(msg)
        emsg = msg.encode()
        return msg, rmsg, emsg

    def route_failed(self, msg):
        if not isinstance(msg, RoutedMessage):
            raise TypeError("msg must already be a RoutedMessage")
        msg.failures += 1
        if msg.failures <= 5:
            self.log.debug("(re)queueing %s for later delivery", repr(msg))
            if len(self.route_queue) == self.route_queue.maxlen:
                self.log.error("route_queue full, discarding %s",
                               repr(self.route_queue[0]))
            self.route_queue.append(msg)
        else:
            self.log.error("discarding %s after %d failures", repr(msg), msg.failures)

    def route_message(self, to, msg):
        if isinstance(to, StupidNode):
            to = to.identity
        if isinstance(to, (list, tuple)):
            to = to[-1]
        R = self.routes.get(to)
        if R:
            to = (R[0], to)
        if isinstance(msg, RoutedMessage):
            msg.to = to
        else:
            # preprocess passes *msg to msg_class() -- ie, RoutedMessage(to, *msg)
            if isinstance(msg, list):
                msg = tuple(msg)
            elif not isinstance(msg, tuple):
                msg = (msg,)
            msg = (to,) + msg
        tmsg, rmsg, emsg = self.preprocess_message(msg, msg_class=RoutedMessage)
        self.log.debug("routing message %s -- encoding: %s", rmsg, emsg)
        try:
            self.router.send_multipart(emsg)
        except zmq.error.ZMQError as zmq_e:
            self.log.debug("route to %s failed: %s", to, zmq_e)
            if "Host unreachable" not in str(zmq_e):
                raise
            self.route_failed(tmsg)

    def deal_message(self, msg):
        self.log.debug("dealing message (actually publishing with no_publish=True)")
        self.publish_message(msg, no_publish=True)

    def publish_message(self, msg, no_deal=False, no_deal_to=None, no_publish=False):
        tmsg, rmsg, emsg = self.preprocess_message(msg)
        self.log.debug(
            "publishing message %s no_publish=%s, no_deal=%s, no_deal_to=%s",
            rmsg, no_publish, no_deal, no_deal_to,
        )
        self.local_workflow(tmsg)
        if not no_publish:
            self.pub.send_multipart(emsg)
        if no_deal:
            return
        if no_deal_to is None:
            ok_send = lambda x: True
        elif callable(no_deal_to):
            ok_send = no_deal_to
        elif isinstance(no_deal_to, zmq.Socket):
            npt_i = self.dealer.index(no_deal_to)
            ok_send = lambda x: x != npt_i
        elif isinstance(no_deal_to, int):
            ok_send = lambda x: x != no_deal_to
        elif isinstance(no_deal_to, (list, tuple)):
            ok_send = lambda x: x not in no_deal_to
        for i, sock in enumerate(self.dealer):
            if ok_send(i):
                self.log.debug("dealing message %s to %s", rmsg, self.endpoints[i])
                sock.send_multipart(emsg)
            else:
                self.log.debug("not sending %s to %s", rmsg, self.endpoints[i])

    def mk_socket(self, stype, enable_curve=True):
        # defaults:
        # socket.setsockopt(zmq.LINGER, -1)  # infinite
        # socket.setsockopt(zmq.IDENTITY, None)
        # socket.setsockopt(zmq.TCP_KEEPALIVE, -1)
        # socket.setsockopt(zmq.TCP_KEEPALIVE_INTVL, -1)
        # socket.setsockopt(zmq.TCP_KEEPALIVE_CNT, -1)
        # socket.setsockopt(zmq.TCP_KEEPALIVE_IDLE, -1)
        # socket.setsockopt(zmq.RECONNECT_IVL, 100)
        # socket.setsockopt(zmq.RECONNECT_IVL_MAX, 0)  # 0 := always use IVL
        #
        # the above can be accessed as attributes instead (they are case
        # insensitive; we choose lower case below so it looks like boring
        # python)
        if enable_curve:
            socket = self.ctx.socket(stype)
            self.log.debug("create %s socket in crypto context",
                           zmq_socket_type_name(stype))
        else:
            socket = self.cleartext_ctx.socket(stype)
            self.log.debug("create %s socket in cleartext context",
                           zmq_socket_type_name(stype))
        socket.linger = 1
        socket.identity = self.identity.encode()
        socket.reconnect_ivl = 1000
        socket.reconnect_ivl_max = 10000
        if enable_curve:
            socket.curve_secretkey = self.privkey
            socket.curve_publickey = self.pubkey
        return socket

    def local_workflow(self, msg):
        self.log.debug("start local_workflow %s", repr(msg))
        msg = self.local_react(msg)
        if msg:
            msg = self.all_react(msg)
        return msg

    def sub_workflow(self, socket):
        idx = self.sub.index(socket)
        enp = self.endpoints[idx]
        msg = self.sub_receive(socket, idx)
        self.log.debug("start sub_workflow (idx=%d -> endpoint=%s) %s",
                       idx, enp, repr(msg))
        for react in (self.sub_react, self.nonlocal_react, self.all_react):
            if msg:
                msg = react(msg, idx=idx)
        self.log.debug("end sub_workflow")
        return msg

    def router_workflow(self):
        msg = self.router_receive()
        self.log.debug("start router_workflow %s", repr(msg))
        for react in (self.router_react, self.nonlocal_react, self.all_react):
            if not msg:
                break
            msg = react(msg)
        self.log.debug("end router_workflow")
        return msg

    def dealer_workflow(self, socket):
        idx = self.dealer.index(socket)
        enp = self.endpoints[idx]
        msg = self.dealer_receive(socket, idx)
        self.log.debug("start deal_workflow (idx=%d -> endpoint=%s) %s",
                       idx, enp, repr(msg))
        for react in (self.dealer_react, self.nonlocal_react, self.all_react):
            if not msg:
                break
            msg = react(msg, idx=idx)
        self.log.debug("end deal_workflow")
        return msg

    def sub_receive(self, socket, idx):  # pylint: disable=unused-argument
        return TaggedMessage(*socket.recv_multipart())

    def dealer_receive(self, socket, idx):  # pylint: disable=unused-argument
        msg = socket.recv_multipart()
        rm = RoutedMessage.decode(msg)
        if rm:
            return rm
        # Dealers always receive a routed message; if it doesn't appear to be
        # routed, then it's simply intended for us. In that case, build a
        # tagged message and mark it as non-publish.
        msg = TaggedMessage(*msg)
        msg.publish_mark = False
        return msg

    def router_receive(self):
        # we ignore the source ID (in '_') and just believe the msg.tag.name
        # ... it's roughly the same thing anyway
        _, *msg = self.router.recv_multipart()
        rm = RoutedMessage.decode(msg)
        if rm:
            return rm
        return TaggedMessage(*msg)

    def all_react(self, msg, idx=None):  # pylint: disable=unused-argument
        return msg

    def sub_react(self, msg, idx=None):  # pylint: disable=unused-argument
        return msg

    def dealer_react(self, msg, idx=None):  # pylint: disable=unused-argument
        return msg

    def router_react(self, msg):
        return msg

    def nonlocal_react(self, msg, idx=None):
        if isinstance(msg, RoutedMessage):
            msg = self.routed_react(msg, idx=idx)
        return msg

    def local_react(self, msg):
        return msg

    def routed_react(self, msg, idx=None):  # pylint: disable=unused-argument
        return False

    def poll(self, timeo=500, other_cb=None):
        """Check to see if there are any incoming messages.

        If anything seems ready to receive, invoke the related workflow or
        invoke other_cb (if given) on the socket item.
        """
        items = dict(self.poller.poll(timeo))
        ret = list()
        for item in items:
            if items[item] != zmq.POLLIN:
                continue
            if item in self.sub:
                res = self.sub_workflow(item)
            elif item in self.dealer:
                res = self.dealer_workflow(item)
            elif item is self.router:
                res = self.router_workflow()
            elif callable(other_cb):
                res = other_cb(item)
            else:
                res = None
                # debug-only branch: short-circuited by the "if False" guard
                if False and isinstance(item, zmq.Socket):
                    self.log.error(
                        "no workflow defined for socket of type %s -- received: %s",
                        zmq_socket_type_name(item),
                        item.recv_multipart(),
                    )
                else:
                    self.log.error(
                        "no workflow defined for socket of type %s -- regarding as fatal",
                        zmq_socket_type_name(item),
                    )
                    # note: this normally doesn't trigger an exit...
                    # thanks, threading
                    raise Exception("unhandled poll item")
            if isinstance(res, TaggedMessage):
                ret.append(res)
        return ret

    def interrupt(self, signo, eframe):  # pylint: disable=unused-argument
        print(" kaboom")
        self.closekill()
        sys.exit(0)

    def closekill(self):
        if hasattr(self, "auth") and self.auth is not None:
            if self.auth.is_alive():
                self.log.debug("trying to stop auth thread")
                self.auth.stop()
                self.log.debug("auth thread seems to have stopped")
            del self.auth

        if hasattr(self, "_who_are_you_thread"):
            if self._who_are_you_thread.is_alive():
                self.log.debug("WAI Thread seems to be alive, trying to join")
                self._who_are_you_continue = False
                self._who_are_you_thread.join()
                self.log.debug("WAI Thread seems to have joined us.")
            del self._who_are_you_thread

        if hasattr(self, "cleartext_ctx"):
            self.log.debug("destroying cleartext context")
            self.cleartext_ctx.destroy(1)
            del self.cleartext_ctx

        if hasattr(self, "ctx"):
            self.log.debug("destroying crypto context")
            self.ctx.destroy(1)
            del self.ctx

    def __del__(self):
        self.log.debug("%s is being deleted", self)
        self.closekill()

    def bind(self, socket, enable_curve=True):
        if enable_curve:
            socket.curve_server = True  # must come before bind
        try:
            f = self.endpoint.format(socket.type)
            socket.bind(f)
        except zmq.ZMQError as e:
            raise zmq.ZMQError(f"unable to bind {f}: {e}") from e

    def who_are_you_request(self, endpoint):
        req = self.mk_socket(zmq.REQ, enable_curve=False)
        req.connect(endpoint.format(zmq.REQ))
        msg = b"Who are you?"
        self.log.debug("sending cleartext request: %s", msg)
        req.send(msg)
        self.log.debug("waiting for reply")
        res = req.recv_multipart()
        self.log.debug("received reply: %s", res)
        # close the socket before returning, whatever the reply looked like
        req.close()
        if len(res) == 2:
            return res
        return None, None

    def pubkey_pathname(self, node_id):
        if isinstance(node_id, Endpoint):
            node_id = node_id.host
        fname = scrub_identity_name_for_certfile(node_id) + ".key"
        pname = os.path.join(self.keyring, fname)
        return pname

    def learn_or_load_endpoint_pubkey(self, endpoint):
        epubk_pname = self.pubkey_pathname(endpoint)
        if not os.path.isfile(epubk_pname):
            self.log.debug("%s does not exist yet, trying to learn certificate",
                           epubk_pname)
            node_id, public_key = self.who_are_you_request(endpoint)
            if node_id:
                endpoint.identity = node_id.decode()
                epubk_pname = self.pubkey_pathname(node_id)
                if not os.path.isfile(epubk_pname):
                    with open(epubk_pname, "wb") as fh:
                        fh.write(b"# generated via rep/req pubkey transfer\n\n")
                        fh.write(b"metadata\n")
                        # NOTE: in zmq/auth/certs.py's _write_key_file,
                        # metadata should be key-value pairs; roughly like the
                        # following (although with their particular py2/py3
                        # neuroses edited out):
                        #
                        #   f.write('metadata\n')
                        #   for k, v in metadata.items():
                        #       f.write(f" {k} = {v}\n")
                        fh.write(b"curve\n")
                        fh.write(b' public-key = "')
                        fh.write(public_key)
                        fh.write(b'"')
        self.log.debug("loading certificate %s", epubk_pname)
        ret, _ = zmq.auth.load_certificate(epubk_pname)
        return ret

    def connect_to_endpoints(self, *endpoints):
        self.log.debug("connecting remote endpoints")
        for item in endpoints:
            self.connect_to_endpoint(item)
        self.log.debug("remote endpoints connected")
        return self

    def _create_connected_socket(self, endpoint, stype, pubkey, preconnect=None):
        self.log.debug("creating %s socket to endpoint=%s",
                       zmq_socket_type_name(stype), endpoint)
        s = self.mk_socket(stype)
        s.curve_serverkey = pubkey
        if callable(preconnect):
            preconnect(s)
        s.connect(endpoint.format(stype))
        return s

    def connect_to_endpoint(self, endpoint):
        if isinstance(endpoint, StupidNode):
            endpoint = endpoint.endpoint
        elif not isinstance(endpoint, Endpoint):
            endpoint = Endpoint(endpoint)
        self.log.debug("learning or loading endpoint=%s pubkey", endpoint)
        epk = self.learn_or_load_endpoint_pubkey(endpoint)

        sos = lambda s: s.setsockopt_string(zmq.SUBSCRIBE, self.channel)
        sub = self._create_connected_socket(endpoint, zmq.SUB, epk, sos)
        self.poller.register(sub, zmq.POLLIN)
        self.sub.append(sub)

        deal = self._create_connected_socket(endpoint, zmq.DEALER, epk)
        self.poller.register(deal, zmq.POLLIN)
        self.dealer.append(deal)

        self.endpoints.append(endpoint)
        return self

    def __repr__(self):
        return f"{self.__class__.__name__}({self.identity})"
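
# Hedged usage sketch for StupidNode. The Endpoint syntax ("host:port") and
# the exact behaviour of poll() are assumptions based on the code above; the
# identities "alice" and "bob" and the port numbers are illustrative only.
# bob first learns alice's public key over the cleartext REP/REQ channel,
# then connects his SUB and DEALER sockets to her PUB and ROUTER sockets.
if __name__ == "__main__":
    alice = StupidNode("localhost:5577", identity="alice")
    bob = StupidNode("localhost:5588", identity="bob")
    bob.connect_to_endpoint(alice)  # also accepts a "host:port" string

    alice.publish_message("hello, bob")
    for tagged in bob.poll(timeo=1000):  # expect a TaggedMessage from alice
        print(tagged)

    alice.closekill()
    bob.closekill()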