def _run(self):
    """
    Start a loop to process the ZMQ requests from the signaler client.

    Binds a CURVE-secured REP socket on ``self.BIND_ADDR`` and polls it
    without blocking for as long as ``self._do_work`` is set. Each request
    is acknowledged with "OK" before being handed to
    ``self._process_request``.

    The authenticator thread, the socket and the context are released when
    the loop ends, whether it ends normally or because of an error.
    """
    logger.debug("Running SignalerQt loop")
    context = zmq.Context()
    socket = context.socket(zmq.REP)

    # Start an authenticator for this context.
    auth = ThreadAuthenticator(context)
    auth.start()
    try:
        auth.allow('127.0.0.1')
        # Accept any client public key (CURVE_ALLOW_ANY): traffic is
        # encrypted, but clients are only filtered by the IP whitelist above.
        auth.configure_curve(domain='*', location=zmq.auth.CURVE_ALLOW_ANY)

        public, secret = get_frontend_certificates()
        socket.curve_publickey = public
        socket.curve_secretkey = secret
        socket.curve_server = True  # must come before bind

        socket.bind(self.BIND_ADDR)

        while self._do_work.is_set():
            # Wait for next request from client
            try:
                request = socket.recv(zmq.NOBLOCK)
                # logger.debug("Received request: '{0}'".format(request))
                socket.send("OK")
                self._process_request(request)
            except zmq.ZMQError as e:
                # EAGAIN just means "no message yet" for a NOBLOCK recv.
                if e.errno != zmq.EAGAIN:
                    raise
            time.sleep(0.01)
    finally:
        # The authenticator runs in its own thread and owns sockets on this
        # context: stop it first, then release our socket and the context.
        auth.stop()
        socket.close(linger=0)
        context.term()

    logger.debug("SignalerQt thread stopped.")
def _init_zmq(self):
    """
    Configure the zmq components and connection.

    Creates a REP socket bound to ``self.BIND_ADDR`` and stores it in
    ``self._zmq_socket``.

    When ``flags.ZMQ_HAS_CURVE`` is set, the socket becomes a CURVE server
    backed by a thread authenticator. Otherwise ``self.SOCKET_FILE`` is
    restricted to owner read/write after binding.
    """
    context = zmq.Context()
    socket = context.socket(zmq.REP)

    if flags.ZMQ_HAS_CURVE:
        # Start an authenticator for this context.
        auth = ThreadAuthenticator(context)
        auth.start()
        # XXX do not hardcode this here.
        auth.allow('127.0.0.1')

        # Accept any client public key; access is limited by the IP
        # whitelist above.
        auth.configure_curve(domain='*', location=zmq.auth.CURVE_ALLOW_ANY)
        public, secret = get_backend_certificates()
        socket.curve_publickey = public
        socket.curve_secretkey = secret
        socket.curve_server = True  # must come before bind

    socket.bind(self.BIND_ADDR)

    if not flags.ZMQ_HAS_CURVE:
        # Without CURVE, file permissions are the only access control on
        # the socket file. 0o600 (owner read/write) is the same value as the
        # Python-2-only literal 0600, but valid on Python 2.6+ and 3.
        os.chmod(self.SOCKET_FILE, 0o600)

    self._zmq_socket = socket
class ZMQPull(ZMQ):
    """Inbound ZMQ connector using a PULL socket."""

    classname = "ZMQPull"

    def __init__(self, name, options, inbound):
        super().__init__(name, options, inbound)
        self.socket_type = zmq.PULL

    def secure_setup(self):
        """Configure CURVE server security on ``self.sock``.

        Starts a thread authenticator on ``self.context`` and loads this
        server's key pair from ``self.secure_config["self"]``.

        When ``self.secure_config["clients"]`` is set, only client keys
        found in that location are accepted. Otherwise any client key is
        accepted.

        If any step fails after the authenticator has started, the
        authenticator is stopped before the exception propagates, so no
        stray thread is left running.
        """
        self.auth = ThreadAuthenticator(self.context)
        self.auth.start()
        try:
            # Load certificates
            self.LOG.debug("Server keys in %s", self.secure_config["self"])
            sock_pub, sock_priv = zmq.auth.load_certificate(self.secure_config["self"])
            if self.secure_config.get("clients", None) is not None:
                self.LOG.debug("Client certificates in %s", self.secure_config["clients"])
                self.auth.configure_curve(domain="*", location=self.secure_config["clients"])
            else:
                self.LOG.debug("Every clients can connect")
                self.auth.configure_curve(domain="*", location=zmq.auth.CURVE_ALLOW_ANY)

            # Setup the socket
            self.sock.curve_publickey = sock_pub
            self.sock.curve_secretkey = sock_priv
            self.sock.curve_server = True
        except Exception:
            # Do not leave a running authenticator thread behind.
            self.auth.stop()
            raise
class ZMQPull(ZMQ):
    """Inbound ZMQ connector using a PULL socket."""

    classname = "ZMQPull"

    def __init__(self, name, options, inbound):
        super().__init__(name, options, inbound)
        self.socket_type = zmq.PULL

    def secure_setup(self):
        """Configure CURVE server security on ``self.sock``.

        Starts a thread authenticator on ``self.context`` and loads this
        server's key pair from ``self.secure_config["self"]``.

        When ``self.secure_config["clients"]`` is set, only client keys
        found in that location are accepted. Otherwise any client key is
        accepted.

        If any step fails after the authenticator has started, the
        authenticator is stopped before the exception propagates, so no
        stray thread is left running.
        """
        self.auth = ThreadAuthenticator(self.context)
        self.auth.start()
        try:
            # Load certificates
            self.LOG.debug("Server keys in %s", self.secure_config["self"])
            sock_pub, sock_priv = load_certificate(self.secure_config["self"])
            if self.secure_config.get("clients", None) is not None:
                self.LOG.debug("Client certificates in %s", self.secure_config["clients"])
                self.auth.configure_curve(domain="*", location=self.secure_config["clients"])
            else:
                self.LOG.debug("Every clients can connect")
                self.auth.configure_curve(domain="*", location=zmq.auth.CURVE_ALLOW_ANY)

            # Setup the socket
            self.sock.curve_publickey = sock_pub
            self.sock.curve_secretkey = sock_priv
            self.sock.curve_server = True
        except Exception:
            # Do not leave a running authenticator thread behind.
            self.auth.stop()
            raise
class Authenticator(object):
    '''CURVE authenticator shared per public-key directory.

    Wraps a ``ThreadAuthenticator`` running on ``zmq.Context.instance()`` and
    offers helpers to configure CURVE server and client sockets. Obtain
    instances through :meth:`instance` so each key directory is served by a
    single authenticator thread.
    '''

    # Live authenticators, keyed by the public key directory they serve.
    _authenticators = {}

    @classmethod
    def instance(cls, public_keys_dir):
        '''Return the shared authenticator for public_keys_dir, creating it on first use.

        Please avoid creating multiple instances; use this accessor instead
        of calling the constructor directly.
        '''
        if public_keys_dir in cls._authenticators:
            return cls._authenticators[public_keys_dir]
        new_instance = cls(public_keys_dir)
        cls._authenticators[public_keys_dir] = new_instance
        return new_instance

    def __init__(self, public_keys_dir):
        # Remembered so stop() can evict this object from the cache.
        self._public_keys_dir = public_keys_dir
        self._auth = ThreadAuthenticator(zmq.Context.instance())
        self._auth.start()
        # NOTE(review): pyzmq's allow() appears to match peer addresses
        # literally (no wildcard support), so '*' may act as a one-entry
        # whitelist rather than "allow everyone". Confirm against the pyzmq
        # version in use.
        self._auth.allow('*')
        # Client public keys are checked against the certificates stored in
        # public_keys_dir.
        self._auth.configure_curve(domain='*', location=public_keys_dir)

    def set_server_key(self, zmq_socket, server_secret_key_path):
        '''Make zmq_socket a CURVE server using the key file server_secret_key_path.

        The key is applied via load_and_set_key. Must be called before bind.
        '''
        load_and_set_key(zmq_socket, server_secret_key_path)
        zmq_socket.curve_server = True

    def set_client_key(self, zmq_socket, client_secret_key_path, server_public_key_path):
        '''Configure zmq_socket as a CURVE client of a given server.

        The client key from client_secret_key_path is applied via
        load_and_set_key. The server's public key is read from
        server_public_key_path. Must be called before connect.
        '''
        load_and_set_key(zmq_socket, client_secret_key_path)
        server_public, _ = zmq.auth.load_certificate(server_public_key_path)
        zmq_socket.curve_serverkey = server_public

    def stop(self):
        '''Stop the authenticator thread and drop this object from the cache.

        After stop(), instance() for the same directory builds a fresh,
        running authenticator instead of returning this stopped one.
        '''
        self._auth.stop()
        cache = type(self)._authenticators
        # Only evict if the cache still points at *this* object (it may have
        # been constructed directly rather than through instance()).
        if cache.get(self._public_keys_dir) is self:
            del cache[self._public_keys_dir]
def _run(self):
    """
    Start a loop to process the ZMQ requests from the signaler client.

    Binds a CURVE-secured REP socket on ``self.BIND_ADDR`` and polls it
    without blocking for as long as ``self._do_work`` is set. Each request
    is logged, acknowledged with "OK" and then handed to
    ``self._process_request``.

    The authenticator thread, the socket and the context are released when
    the loop ends, whether it ends normally or because of an error.
    """
    logger.debug("Running SignalerQt loop")
    context = zmq.Context()
    socket = context.socket(zmq.REP)

    # Start an authenticator for this context.
    auth = ThreadAuthenticator(context)
    auth.start()
    try:
        auth.allow('127.0.0.1')
        # Accept any client public key (CURVE_ALLOW_ANY): traffic is
        # encrypted, but clients are only filtered by the IP whitelist above.
        auth.configure_curve(domain='*', location=zmq.auth.CURVE_ALLOW_ANY)

        public, secret = get_frontend_certificates()
        socket.curve_publickey = public
        socket.curve_secretkey = secret
        socket.curve_server = True  # must come before bind

        socket.bind(self.BIND_ADDR)

        while self._do_work.is_set():
            # Wait for next request from client
            try:
                request = socket.recv(zmq.NOBLOCK)
                # Lazy %-formatting: the string is only built if DEBUG is on.
                logger.debug("Received request: '%s'", request)
                socket.send("OK")
                self._process_request(request)
            except zmq.ZMQError as e:
                # EAGAIN just means "no message yet" for a NOBLOCK recv.
                if e.errno != zmq.EAGAIN:
                    raise
            time.sleep(0.01)
    finally:
        # The authenticator runs in its own thread and owns sockets on this
        # context: stop it first, then release our socket and the context.
        auth.stop()
        socket.close(linger=0)
        context.term()

    logger.debug("SignalerQt thread stopped.")
def _init_txzmq(self):
    """
    Configure the txzmq components and connection.

    Creates a txzmq REP connection whose incoming messages are answered with
    "OK" and then passed to ``self._process_request``. The connection is
    bound to ``self._server_address``.

    When ``flags.ZMQ_HAS_CURVE`` is set, the socket is a CURVE server backed
    by a thread authenticator. For ``ipc://`` addresses the socket file is
    restricted to owner read/write after binding.
    """
    self._zmq_factory = txzmq.ZmqFactory()
    self._zmq_factory.registerForShutdown()
    self._zmq_connection = txzmq.ZmqREPConnection(self._zmq_factory)

    context = self._zmq_factory.context
    socket = self._zmq_connection.socket

    def _gotMessage(messageId, messageParts):
        # Acknowledge first so the REP/REQ exchange completes, then handle.
        self._zmq_connection.reply(messageId, "OK")
        self._process_request(messageParts)

    self._zmq_connection.gotMessage = _gotMessage

    if flags.ZMQ_HAS_CURVE:
        # Start an authenticator for this context.
        auth = ThreadAuthenticator(context)
        auth.start()
        # XXX do not hardcode this here.
        auth.allow('127.0.0.1')

        # Accept any client public key; access is limited by the IP
        # whitelist above.
        auth.configure_curve(domain='*', location=zmq.auth.CURVE_ALLOW_ANY)
        public, secret = get_backend_certificates()
        socket.curve_publickey = public
        socket.curve_secretkey = secret
        socket.curve_server = True  # must come before bind

    # tcp/ipc, ip/socket -- split only on the first separator.
    proto, addr = self._server_address.split('://', 1)
    socket.bind(self._server_address)
    if proto == 'ipc':
        # 0o600 (owner read/write) is the same value as the Python-2-only
        # literal 0600, but valid on Python 2.6+ and 3.
        os.chmod(addr, 0o600)
def main():
    """Run a CURVE-secured Majordomo broker until it stops.

    Starts a thread authenticator on the global zmq context that only
    accepts connections from 127.0.0.1 and any client public key. It then
    serves ``SecureMajorDomoBroker`` with the key loaded from
    ``example/broker.key_secret``; ``sys.argv[1]`` is passed to the broker
    as its second argument.

    The authenticator is stopped however the broker exits (normal return,
    KeyboardInterrupt or any other error). Exceptions are re-raised.
    """
    auth = ThreadAuthenticator(zmq.Context.instance())
    auth.start()
    try:
        auth.allow('127.0.0.1')
        # Tell the authenticator how to handle CURVE requests
        auth.configure_curve(domain='*', location=zmq.auth.CURVE_ALLOW_ANY)

        key = Key.load('example/broker.key_secret')
        broker = SecureMajorDomoBroker(key, sys.argv[1])
        broker.serve_forever()
    finally:
        auth.stop()
def main():
    """Run a CURVE-secured Majordomo broker until it stops.

    Starts a thread authenticator on the global zmq context that only
    accepts connections from 127.0.0.1 and any client public key. It then
    serves ``SecureMajorDomoBroker`` with the key loaded from
    ``example/broker.key_secret``; ``sys.argv[1]`` is passed to the broker
    as its second argument.

    The authenticator is stopped however the broker exits (normal return,
    KeyboardInterrupt or any other error). Exceptions are re-raised.
    """
    auth = ThreadAuthenticator(zmq.Context.instance())
    auth.start()
    try:
        auth.allow('127.0.0.1')
        # Tell the authenticator how to handle CURVE requests
        auth.configure_curve(domain='*', location=zmq.auth.CURVE_ALLOW_ANY)

        key = Key.load('example/broker.key_secret')
        broker = SecureMajorDomoBroker(key, sys.argv[1])
        broker.serve_forever()
    finally:
        auth.stop()
def _start_thread_auth(self, socket):
    """
    Start the zmq curve thread authenticator.

    The authenticator runs on ``self._factory.context``, whitelists
    127.0.0.1, and checks client keys against the certificates in
    ``os.path.join(self._config_prefix, PUBLIC_KEYS_PREFIX)``. The given
    socket is switched to CURVE server mode, so this must be called before
    the socket is bound.

    :param socket: The socket in which to configure the authenticator.
    :type socket: zmq.Socket

    :return: the started authenticator; keep it to be able to ``stop()``
             its thread later.
    :rtype: zmq.auth.thread.ThreadAuthenticator
    """
    authenticator = ThreadAuthenticator(self._factory.context)
    authenticator.start()
    # XXX do not hardcode this here.
    authenticator.allow('127.0.0.1')
    # tell authenticator to use the certificate in a directory
    public_keys_dir = os.path.join(self._config_prefix, PUBLIC_KEYS_PREFIX)
    authenticator.configure_curve(domain="*", location=public_keys_dir)
    socket.curve_server = True  # must come before bind
    return authenticator
def _start_thread_auth(self, socket):
    """
    Start the zmq curve thread authenticator.

    The authenticator runs on ``self._factory.context``, whitelists
    127.0.0.1, and checks client keys against the certificates in
    ``os.path.join(self._config_prefix, PUBLIC_KEYS_PREFIX)``. The given
    socket is switched to CURVE server mode, so this must be called before
    the socket is bound.

    :param socket: The socket in which to configure the authenticator.
    :type socket: zmq.Socket

    :return: the started authenticator; keep it to be able to ``stop()``
             its thread later.
    :rtype: zmq.auth.thread.ThreadAuthenticator
    """
    authenticator = ThreadAuthenticator(self._factory.context)
    authenticator.start()
    # XXX do not hardcode this here.
    authenticator.allow('127.0.0.1')
    # tell authenticator to use the certificate in a directory
    public_keys_dir = os.path.join(self._config_prefix, PUBLIC_KEYS_PREFIX)
    authenticator.configure_curve(domain="*", location=public_keys_dir)
    socket.curve_server = True  # must come before bind
    return authenticator
def auth_init():
    """Create, start and configure a ZAP authenticator on the module ``ctx``.

    The allowed address(es), the CURVE domain and the directory holding
    the client public certificates are all taken from the module-level
    ``opts``. The authenticator logs through jomiel's logger.

    :returns: the started ``ThreadAuthenticator``
    """
    from zmq.auth.thread import ThreadAuthenticator
    from jomiel.log import lg
    from os.path import abspath

    authenticator = ThreadAuthenticator(ctx, log=lg())
    authenticator.start()
    authenticator.allow(opts.curve_allow)

    # Resolve the client-certificate directory to an absolute path before
    # handing it to the authenticator for CURVE checks.
    cert_dir = abspath(opts.curve_public_key_dir)
    authenticator.configure_curve(domain=opts.curve_domain, location=cert_dir)

    return authenticator
class ContextHandler:
    """Own a zmq context together with a CURVE thread authenticator.

    Client public keys are checked against the certificates in
    ``publicPath`` for every ZAP domain.
    """

    def __init__(self, publicPath):
        self.__context = zmq.Context()
        self.publicPath = publicPath
        self.auth = ThreadAuthenticator(self.__context)
        self.auth.start()
        self.configureAuth()
        # Thread.setName() is deprecated (Python 3.10+); assigning .name is
        # equivalent and works on every supported version.
        self.auth.thread.name = "CurveAuth"

    def getContext(self):
        """Return the managed zmq context."""
        return self.__context

    def configureAuth(self):
        """(Re)load the client certificates from ``self.publicPath``."""
        self.auth.configure_curve(domain='*', location=self.publicPath)

    def cleanup(self):
        """Stop the authenticator thread, then destroy the context."""
        # The authenticator's sockets live on this context; stop its thread
        # first so it is not left running against a destroyed context.
        self.auth.stop()
        self.__context.destroy()
class CurveAuthenticator(object):
    """Start a CURVE ``ThreadAuthenticator`` on a given zmq context.

    Exactly one of three policies is applied to ``domain``:

    * ``callback`` given: keys are validated by the callback (used as the
      authenticator's ``credentials_provider``);
    * ``location`` is ``CURVE_ALLOW_ANY`` or ``None``: any client key is
      accepted;
    * otherwise: client certificates are loaded from ``location``.
    """

    def __init__(self, ctx, domain='*', location=zmq.auth.CURVE_ALLOW_ANY, callback=None):
        self._domain = domain
        self._location = location
        self._callback = callback
        self._ctx = ctx
        self._atx = ThreadAuthenticator(self.ctx)
        self._atx.start()
        # All three branches honour the configured domain (the default '*'
        # keeps the previous behaviour for callers that do not pass one).
        if self._callback is not None:
            logging.info('Callback: %s', self._callback)
            self._atx.configure_curve_callback(
                self._domain, credentials_provider=self._callback)
        elif (self._location == zmq.auth.CURVE_ALLOW_ANY
              or self._location is None):
            self._atx.configure_curve(domain=self._domain,
                                      location=zmq.auth.CURVE_ALLOW_ANY)
        else:
            self.load_certs()

    @property
    def atx(self):
        """The underlying ``ThreadAuthenticator``."""
        return self._atx

    @property
    def location(self):
        """Certificate location passed at construction."""
        return self._location

    @property
    def domain(self):
        """ZAP domain the policy applies to."""
        return self._domain

    @property
    def ctx(self):
        """zmq context the authenticator runs on."""
        return self._ctx

    def load_certs(self):
        """(Re)load client certificates from ``self.location`` for ``self.domain``."""
        self.atx.configure_curve(domain=self._domain, location=self._location)
def _start_thread_auth(self, socket):
    """
    Start the zmq curve thread authenticator.

    The authenticator runs on ``self._factory.context``, whitelists
    127.0.0.1, and checks client keys against the certificates in
    ``os.path.join(self._config_prefix, PUBLIC_KEYS_PREFIX)``. The given
    socket is switched to CURVE server mode, so this must be called before
    the socket is bound.

    :param socket: The socket in which to configure the authenticator.
    :type socket: zmq.Socket

    :return: the started authenticator; keep it to be able to ``stop()``
             its thread later.
    :rtype: zmq.auth.thread.ThreadAuthenticator
    """
    authenticator = ThreadAuthenticator(self._factory.context)

    # Temporary fix until we understand what the problem is
    # See https://leap.se/code/issues/7536
    time.sleep(0.5)

    authenticator.start()
    # XXX do not hardcode this here.
    authenticator.allow('127.0.0.1')
    # tell authenticator to use the certificate in a directory
    public_keys_dir = os.path.join(self._config_prefix, PUBLIC_KEYS_PREFIX)
    authenticator.configure_curve(domain="*", location=public_keys_dir)
    socket.curve_server = True  # must come before bind
    return authenticator
def _init_zmq(self):
    """
    Configure the zmq components and connection.

    Builds a CURVE-server REP socket on a fresh context, protected by a
    thread authenticator that whitelists 127.0.0.1 and accepts any client
    public key. The socket is bound to ``self.BIND_ADDR`` and kept in
    ``self._zmq_socket``.
    """
    ctx = zmq.Context()
    rep_sock = ctx.socket(zmq.REP)

    # Authenticator thread for this context: IP whitelist plus
    # "any client key" CURVE policy.
    authenticator = ThreadAuthenticator(ctx)
    authenticator.start()
    authenticator.allow('127.0.0.1')
    authenticator.configure_curve(domain='*', location=zmq.auth.CURVE_ALLOW_ANY)

    # Server key pair; curve_server has to be set before binding.
    pub_key, sec_key = get_backend_certificates()
    rep_sock.curve_publickey = pub_key
    rep_sock.curve_secretkey = sec_key
    rep_sock.curve_server = True

    rep_sock.bind(self.BIND_ADDR)
    self._zmq_socket = rep_sock
def setup_auth():
    """Enable CURVE authentication on the module socket ``_sock`` if configured.

    Reads the ``'auth'`` section of ``_options``; when it is absent nothing
    happens. Otherwise a ``ThreadAuthenticator`` on ``_zctx`` is started and
    stored in the module-level ``_auth``. It uses the optional
    ``'whitelist'`` and the public certificates in ``'public_key_dir'``
    (default ``public_keys``). ``_sock`` is then made a CURVE server with
    the key pair from ``'private_key_dir'``/``'private_key_file'``
    (defaults ``private_keys``/``server.key_secret``). Relative directories
    are resolved against the parent of this file's directory.

    Setup is best effort: on any failure the authenticator (if one was
    created) is stopped, ``_auth`` is reset to ``None`` and the error is
    not propagated.
    """
    global _auth
    assert _options is not None
    auth = _options.get('auth', None)
    if auth is None:
        return
    base_dir = os.path.abspath(
        os.path.join(os.path.dirname(os.path.abspath(__file__)), '..'))
    # Track the authenticator locally so the error path never calls stop()
    # on a stale or None ``_auth`` when construction itself fails.
    authenticator = None
    try:
        authenticator = ThreadAuthenticator(_zctx)
        _auth = authenticator
        authenticator.start()
        whitelist = auth.get('whitelist', None)
        if whitelist is not None:
            authenticator.allow(whitelist)
        public_path = auth.get('public_key_dir', 'public_keys')
        authenticator.configure_curve(
            domain='*', location=getExistsPath(base_dir, public_path))
        private_dir = getExistsPath(
            base_dir, auth.get('private_key_dir', 'private_keys'))
        private_key = os.path.join(
            private_dir, auth.get('private_key_file', 'server.key_secret'))
        server_public, server_private = zmq.auth.load_certificate(private_key)
        _sock.curve_secretkey = server_private
        _sock.curve_publickey = server_public
        _sock.curve_server = True
    except Exception:
        # Best effort: leave no running authenticator behind.
        if authenticator is not None:
            authenticator.stop()
        _auth = None
class Device(Actor):
    '''
    Actor specialisation that hosts a single device component.

    The actor class implements all the management and control functions
    over its components. For a device, an actor model is synthesised on
    the fly. It holds one instance, named after the instance part of the
    qualified name, which wraps the device component.
    '''

    def __init__(self, gModel, gModelName, dName, qName, sysArgv):
        '''
        Constructor

        :param gModel: application model; this constructor reads its "name",
            "devices", "actors", "messages", "groups" and "components" entries
        :type gModel: dict

        :param gModelName: model name, stored as self.modelName
        :type gModelName: str

        :param dName: device type name
        :type dName: str

        :param qName: qualified name of the device instance: 'actor.inst'
        :type qName: str

        :param sysArgv: command-line arguments, passed to parseParams()

        :raises BuildError: if dName is not a known device, if the loaded
            security key pair is inconsistent, or if an instance references
            an undefined device type
        '''
        self.logger = logging.getLogger(__name__)
        self.inst_ = self
        self.appName = gModel["name"]
        self.modelName = gModelName
        aName, iName = qName.split('.')
        self.name = qName
        self.iName = iName
        self.dName = dName
        self.pid = os.getpid()
        self.uuid = None
        self.suffix = ""
        self.setupIfaces()
        # Assumption : pid is a 4 byte int
        self.actorID = ipaddress.IPv4Address(self.globalHost).packed + self.pid.to_bytes(4, 'big')
        if dName not in gModel["devices"]:
            raise BuildError('Device "%s" unknown' % dName)

        # In order to make the rest of the code work, we build an actor model for the device
        devModel = gModel["devices"][dName]
        self.model = {}  # The made-up actor model

        formals = devModel["formals"]  # Formals are the same as those of the device (component)
        self.model["formals"] = formals

        devInst = {"type": dName}  # There is a single instance, containing the device component
        actuals = []
        for arg in formals:
            name = arg["name"]
            actual = {}
            actual["name"] = name
            actual["param"] = name
            actuals.append(actual)
        devInst["actuals"] = actuals

        self.model["instances"] = {iName: devInst}  # Single instance (under iName)

        aModel = gModel["actors"][aName]
        self.model["locals"] = aModel["locals"]  # Locals
        self.model["internals"] = aModel["internals"]  # Internals

        self.INT_RE = re.compile(r"^[-]?\d+$")
        self.parseParams(sysArgv)

        # Use czmq's context
        czmq_ctx = Zsys.init()
        self.context = zmq.Context.shadow(czmq_ctx.value)
        Zsys.handler_reset()  # Reset previous signal

        if Config.SECURITY:
            # Separate context for app sockets. It is only created when it is
            # used; in the insecure case app sockets share self.context.
            self.appContext = zmq.Context()
            (self.public_key, self.private_key) = zmq.auth.load_certificate(const.appCertFile)
            # The stored public key must match the one derived from the secret key.
            _public = zmq.curve_public(self.private_key)
            if self.public_key != _public:
                self.logger.error("bad security key(s)")
                raise BuildError("invalid security key(s)")
            hosts = ['127.0.0.1']
            try:
                with open(const.appDescFile, 'r') as f:
                    # NOTE(review): yaml.Loader can construct arbitrary Python
                    # objects; this is only safe if appDescFile is trusted.
                    content = yaml.load(f, Loader=yaml.Loader)
                    # NOTE(review): attribute access -- if the descriptor
                    # loads as a plain dict this raises and no extra hosts
                    # are whitelisted. Confirm the expected document type.
                    hosts += content.hosts
            except Exception as e:
                self.logger.error("Error loading app descriptor: %s", e)

            self.auth = ThreadAuthenticator(self.appContext)
            self.auth.start()
            self.auth.allow(*hosts)
            self.auth.configure_curve(domain='*', location=zmq.auth.CURVE_ALLOW_ANY)
        else:
            (self.public_key, self.private_key) = (None, None)
            self.auth = None
            self.appContext = self.context

        try:
            if os.path.isfile(const.logConfFile) and os.access(const.logConfFile, os.R_OK):
                spdlog_setup.from_file(const.logConfFile)
        except Exception as e:
            self.logger.error("error while configuring componentLogger: %s" % repr(e))

        messages = gModel["messages"]  # Global message types (global on the network)
        self.messageNames = []
        for messageSpec in messages:
            self.messageNames.append(messageSpec["name"])

        locals_ = self.model["locals"]  # Local message types (local to the host)
        self.localNames = []
        for messageSpec in locals_:
            self.localNames.append(messageSpec["type"])

        internals = self.model["internals"]  # Internal message types (internal to the actor process)
        self.internalNames = []
        for messageSpec in internals:
            self.internalNames.append(messageSpec["type"])

        groups = gModel["groups"]
        self.groupTypes = {}
        for group in groups:
            self.groupTypes[group["name"]] = {
                "kind": group["kind"],
                "message": group["message"],
                "timed": group["timed"]}

        self.components = {}
        instSpecs = self.model["instances"]
        _compSpecs = gModel["components"]  # unused, but a missing key still raises KeyError here
        devSpecs = gModel["devices"]
        for instName in instSpecs:  # Create the component instances: the 'parts'
            instSpec = instSpecs[instName]
            instType = instSpec['type']
            if instType in devSpecs:
                typeSpec = devSpecs[instType]
            else:
                raise BuildError('Device type "%s" for instance "%s" is undefined' % (instType, instName))
            instFormals = typeSpec['formals']
            instActuals = instSpec['actuals']
            instArgs = self.buildInstArgs(instName, instFormals, instActuals)

            # Check whether the component is C++ component
            ccComponentFile = 'lib' + instType.lower() + '.so'
            ccComp = os.path.isfile(ccComponentFile)
            try:
                if ccComp:
                    modObj = importlib.import_module('lib' + instType.lower())
                    self.components[instName] = modObj.create_component_py(self, self.model, typeSpec, instName, instType, instArgs, self.appName, self.name, groups)
                else:
                    self.components[instName] = Part(self, typeSpec, instName, instType, instArgs)
            except Exception as e:
                # A failing part is reported and skipped; the device is still built.
                traceback.print_exc()
                self.logger.error("Error while constructing part '%s.%s': %s" % (instType, instName, str(e)))

    def getPortMessageTypes(self, ports, key, kinds, res):
        '''Append {"type": ...} entries to res for each port under ports[key].

        For every port spec, the message type names stored under each field
        listed in kinds are collected.
        '''
        for _name, spec in ports[key].items():
            for kind in kinds:
                typeName = spec[kind]
                res.append({"type": typeName})

    def getMessageTypes(self, devModel):
        '''Return the message types used by all ports of devModel.'''
        res = []
        ports = devModel["ports"]
        self.getPortMessageTypes(ports, "pubs", ["type"], res)
        self.getPortMessageTypes(ports, "subs", ["type"], res)
        self.getPortMessageTypes(ports, "reqs", ["req_type", "rep_type"], res)
        self.getPortMessageTypes(ports, "reps", ["req_type", "rep_type"], res)
        self.getPortMessageTypes(ports, "clts", ["req_type", "rep_type"], res)
        self.getPortMessageTypes(ports, "srvs", ["req_type", "rep_type"], res)
        self.getPortMessageTypes(ports, "qrys", ["req_type", "rep_type"], res)
        self.getPortMessageTypes(ports, "anss", ["req_type", "rep_type"], res)
        return res

    def isDevice(self):
        '''Always True for this class.'''
        return True

    def setup(self):
        '''
        Perform a setup operation on the actor (after the initial construction but before the activation of parts)

        Registers with the discovery and deployment services, then creates
        and binds one inproc PAIR control socket per component and calls
        each component's setup().
        '''
        self.logger.info("setup")
        # self.setupIfaces()
        self.suffix = self.macAddress
        self.disco = DiscoClient(self, self.suffix)
        self.disco.start()  # Start the discovery service client
        self.disco.registerActor()  # Register this actor with the discovery service
        self.logger.info("device registered with disco")
        self.deplc = DeplClient(self, self.suffix)
        self.deplc.start()
        ok = self.deplc.registerActor()
        self.logger.info("device %s registered with depl" % ("is" if ok else "is not"))
        self.controls = {}
        self.controlMap = {}
        for inst in self.components:
            comp = self.components[inst]
            control = self.context.socket(zmq.PAIR)
            control.bind('inproc://part_' + inst + '_control')
            self.controls[inst] = control
            self.controlMap[id(control)] = comp
            # Python Parts take the control socket; other components do not.
            if isinstance(comp, Part):
                self.components[inst].setup(control)
            else:
                self.components[inst].setup()

    def terminate(self):
        '''Terminate all components and the discovery client, then exit the process.'''
        self.logger.info("terminating")
        for component in self.components.values():
            component.terminate()
        # self.devc.terminate()
        self.disco.terminate()
        # Clean up everything
        # self.context.destroy()
        time.sleep(1.0)
        self.logger.info("terminated")
        # Hard exit: skips interpreter cleanup and any remaining threads.
        os._exit(0)
class MultiNodeAgent(BEMOSSAgent): def __init__(self, *args, **kwargs): super(MultiNodeAgent, self).__init__(*args, **kwargs) self.multinode_status = dict() self.getMultinodeData() self.agent_id = 'multinodeagent' self.is_parent = False self.last_sync_with_parent = datetime(1991, 1, 1) #equivalent to -ve infinitive self.parent_node = None self.recently_online_node_list = [] # initialize to lists to empty self.recently_offline_node_list = [ ] # they will be filled as nodes are discovered to be online/offline self.setup() self.runPeriodically(self.send_heartbeat, 20) self.runPeriodically(self.check_health, 60, start_immediately=False) self.runPeriodically(self.sync_all_with_parent, 600) self.subscribe('relay_message', self.relayDirectMessage) self.subscribe('update_multinode_data', self.updateMultinodeData) self.runContinuously(self.pollClients) self.run() def getMultinodeData(self): self.multinode_data = db_helper.get_multinode_data() self.nodelist_dict = { node['name']: node for node in self.multinode_data['known_nodes'] } self.node_name_list = [ node['name'] for node in self.multinode_data['known_nodes'] ] self.address_list = [ node['address'] for node in self.multinode_data['known_nodes'] ] self.server_key_list = [ node['server_key'] for node in self.multinode_data['known_nodes'] ] self.node_name = self.multinode_data['this_node'] for index, node in enumerate(self.multinode_data['known_nodes']): if node['name'] == self.node_name: self.node_index = index break else: raise ValueError( '"this_node:" entry on the multinode_data json file is invalid' ) for node_name in self.node_name_list: #initialize all nodes data if node_name not in self.multinode_status: #initialize new nodes. 
There could be already the node if this getMultiNode # data is being called later self.multinode_status[node_name] = dict() self.multinode_status[node_name][ 'health'] = -10 #initialized; never online/offline self.multinode_status[node_name]['last_sync_time'] = datetime( 1991, 1, 1) self.multinode_status[node_name]['last_online_time'] = None self.multinode_status[node_name]['last_offline_time'] = None self.multinode_status[node_name]['last_scanned_time'] = None def setup(self): print "Setup" base_dir = settings.PROJECT_DIR + "/" public_keys_dir = os.path.abspath(os.path.join(base_dir, 'public_keys')) secret_keys_dir = os.path.abspath( os.path.join(base_dir, 'private_keys')) self.secret_keys_dir = secret_keys_dir self.public_keys_dir = public_keys_dir if not (os.path.exists(public_keys_dir) and os.path.exists(secret_keys_dir)): logging.critical( "Certificates are missing - run generate_certificates.py script first" ) sys.exit(1) ctx = zmq.Context.instance() self.ctx = ctx # Start an authenticator for this context. 
self.auth = ThreadAuthenticator(ctx) self.auth.start() self.configure_authenticator() server = ctx.socket(zmq.PUB) server_secret_key_filename = self.multinode_data['known_nodes'][ self.node_index]['server_secret_key'] server_secret_file = os.path.join(secret_keys_dir, server_secret_key_filename) server_public, server_secret = zmq.auth.load_certificate( server_secret_file) server.curve_secretkey = server_secret server.curve_publickey = server_public server.curve_server = True # must come before bind server.bind( self.multinode_data['known_nodes'][self.node_index]['address']) self.server = server self.configureClient() def configure_authenticator(self): self.auth.allow() # Tell authenticator to use the certificate in a directory self.auth.configure_curve(domain='*', location=self.public_keys_dir) def disperseMessage(self, sender, topic, message): for node_name in self.node_name_list: if node_name == self.node_name: continue self.server.send( jsonify(sender, node_name + '/republish/' + topic, message)) def republishToParent(self, sender, topic, message): if self.is_parent: return #if I am parent, the message is already published for node_name in self.node_name_list: if self.multinode_status[node_name][ 'health'] == 2: #health = 2 is the parent node self.server.send( jsonify(sender, node_name + '/republish/' + topic, message)) def sync_node_with_parent(self, node_name): if self.is_parent: print "Syncing " + node_name self.last_sync_with_parent = datetime.now() sync_date_string = self.last_sync_with_parent.strftime( '%B %d, %Y, %H:%M:%S') # os.system('pg_dump bemossdb -f ' + self.self_database_dump_path) # with open(self.self_database_dump_path, 'r') as f: # file_content = f.read() # msg = {'database_dump': base64.b64encode(file_content)} self.server.send( jsonify( self.node_name, node_name + '/sync-with-parent/' + sync_date_string + '/' + self.node_name, "")) def sync_all_with_parent(self, dbcon): if self.is_parent: self.last_sync_with_parent = datetime.now() 
sync_date_string = self.last_sync_with_parent.strftime( '%B %d, %Y, %H:%M:%S') print "Syncing all nodes" for node_name in self.node_name_list: if node_name == self.node_name: continue # os.system('pg_dump bemossdb -f ' + self.self_database_dump_path) # with open(self.self_database_dump_path, 'r') as f: # file_content = f.read() # msg = {'database_dump': base64.b64encode(file_content)} self.server.send( jsonify( self.node_name, node_name + '/sync-with-parent/' + sync_date_string + '/' + self.node_name, "")) def send_heartbeat(self, dbcon): #self.vip.pubsub.publish('pubsub', 'listener', None, {'message': 'Hello Listener'}) #print 'publishing' print "Sending heartbeat" last_sync_string = self.last_sync_with_parent.strftime( '%B %d, %Y, %H:%M:%S') self.server.send( jsonify( self.node_name, 'heartbeat/' + self.node_name + '/' + str(self.is_parent) + '/' + last_sync_string, "")) def extract_ip(self, addr): return re.search(r'([0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3})', addr).groups()[0] def getNodeId(self, node_name): for index, node in enumerate(self.multinode_data['known_nodes']): if node['name'] == node_name: node_index = index break else: raise ValueError('the node name: ' + node_name + ' is not found in multinode data') return node_index def getNodeName(self, node_id): return self.multinode_data['known_nodes'][node_id]['name'] def handle_offline_nodes(self, dbcon, node_name_list): if self.is_parent: # start all the agents belonging to that node on this node command_group = [] for node_name in node_name_list: node_id = self.getNodeId(node_name) #put the offline event into cassandra events log table, and also create notification self.EventRegister(dbcon, 'node-offline', reason='communication-error', source=node_name) # get a list of agents that were supposedly running in that offline node dbcon.execute( "SELECT agent_id FROM " + node_devices_table + " WHERE assigned_node_id=%s", (node_id, )) if dbcon.rowcount: agent_ids = dbcon.fetchall() for agent_id in 
agent_ids: message = dict() message[STATUS_CHANGE.AGENT_ID] = agent_id[0] message[STATUS_CHANGE.NODE] = str(self.node_index) message[STATUS_CHANGE.AGENT_STATUS] = 'start' message[ STATUS_CHANGE. NODE_ASSIGNMENT_TYPE] = ZONE_ASSIGNMENT_TYPES.TEMPORARY command_group += [message] dbcon.execute( "UPDATE " + node_devices_table + " SET current_node_id=(%s), date_move=(%s)" " WHERE agent_id=(%s)", (self.node_index, datetime.now( pytz.UTC), agent_id[0])) dbcon.commit() print "moving agents from offline node to parent: " + str( node_name_list) print command_group if command_group: self.bemoss_publish(target='networkagent', topic='status_change', message=command_group) def handle_online_nodes(self, dbcon, node_name_list): if self.is_parent: # start all the agents belonging to that nodes back on them command_group = [] for node_name in node_name_list: node_id = self.getNodeId(node_name) if self.node_index == node_id: continue #don't handle self-online self.EventRegister(dbcon, 'node-online', reason='communication-restored', source=node_name) #get a list of agents that were supposed to be running in that online node dbcon.execute( "SELECT agent_id FROM " + node_devices_table + " WHERE assigned_node_id=%s", (node_id, )) if dbcon.rowcount: agent_ids = dbcon.fetchall() for agent_id in agent_ids: message = dict() message[STATUS_CHANGE.AGENT_ID] = agent_id[0] message[ STATUS_CHANGE. NODE_ASSIGNMENT_TYPE] = ZONE_ASSIGNMENT_TYPES.PERMANENT message[STATUS_CHANGE.NODE] = str(self.node_index) message[STATUS_CHANGE. AGENT_STATUS] = 'stop' #stop in this node command_group += [message] message = dict(message) #create another copy message[STATUS_CHANGE.NODE] = str(node_id) message[ STATUS_CHANGE. 
AGENT_STATUS] = 'start' #start in the target node command_group += [message] #immediately update the multnode device assignment table dbcon.execute( "UPDATE " + node_devices_table + " SET current_node_id=(%s), date_move=(%s)" " WHERE agent_id=(%s)", (node_id, datetime.now(pytz.UTC), agent_id[0])) dbcon.commit() print "Moving agents back to the online node: " + str( node_name_list) print command_group if command_group: self.bemoss_publish(target='networkagent', topic='status_change', message=command_group) def updateParent(self, dbcon, parent_node_name): parent_ip = self.extract_ip( self.nodelist_dict[parent_node_name]['address']) write_new = False if not os.path.isfile(settings.MULTINODE_PARENT_IP_FILE ): # but parent file doesn't exists write_new = True else: with open(settings.MULTINODE_PARENT_IP_FILE, 'r') as f: read_ip = f.read() if read_ip != parent_ip: write_new = True if write_new: with open(settings.MULTINODE_PARENT_IP_FILE, 'w') as f: f.write(parent_ip) if dbcon: dbcon.close() #close old connection dbcon = db_helper.db_connection( ) #start new connection using new parent_ip self.updateMultinodeData(sender=self.name, topic='update_parent', message="") def check_health(self, dbcon): for node_name, node in self.multinode_status.items(): if node['health'] > 0: #initialize all online nodes to 0. If they are really online, they should change it # back to 1 or 2 (parent) within 30 seconds throught the heartbeat. node['health'] = 0 time.sleep(30) parent_node_name = None #initialize parent node online_node_exists = False for node_name, node in self.multinode_status.items(): node['last_scanned_time'] = datetime.now() if node['health'] == 0: node['health'] = -1 node['last_offline_time'] = datetime.now() self.recently_offline_node_list += [node_name] elif node['health'] == -1: #offline since long pass elif node[ 'health'] == -10: #The node was initialized to -10, and never came online. 
Treat it as recently going # offline for this iteration so that the agents that were supposed to be running there can be migrated node['health'] = -1 self.recently_offline_node_list += [node_name] elif node['health'] == 2: #there is some parent node present parent_node_name = node_name if node['health'] > 0: online_node_exists = True #At-least one node (itself) should be online, if not some problem assert online_node_exists, "At least one node (current node) must be online" if not parent_node_name: #parent node doesn't exist #find a suitable node to elect a parent. The node with latest update from previous parent wins. If there is #tie, then the node coming earlier in the node-list on multinode data wins online_node_last_sync = dict( ) #only the online nodes, and their last-sync-times for node_name, node in self.multinode_status.items( ): #copy only the online nodes if node['health'] > 0: online_node_last_sync[node_name] = node['last_sync_time'] latest_node = max(online_node_last_sync, key=online_node_last_sync.get) latest_sync_date = online_node_last_sync[latest_node] for node_name in self.node_name_list: if self.multinode_status[node_name][ 'health'] <= 0: #dead nodes can't be parents continue if self.multinode_status[node_name][ 'last_sync_time'] == latest_sync_date: # this is the first node with the latest update from parent #elligible parent found self.updateParent(dbcon, node_name) if node_name == self.node_name: # I am the node, so I get to become the parent self.is_parent = True print "I am the boss now, " + self.node_name break else: #I-am-not-the-first-node with latest update; somebody else is self.is_parent = False break else: #parent node exist self.updateParent(dbcon, parent_node_name) for node in self.multinode_data['known_nodes']: print node['name'] + ': ' + str( self.multinode_status[node['name']]['health']) if self.is_parent: #if this is a parent node, update the node_info table if dbcon is None: #if no database connection exists make connection 
dbcon = db_helper.db_connection() tbl_node_info = settings.DATABASES['default']['TABLE_node_info'] dbcon.execute('select node_id from ' + tbl_node_info) to_be_deleted_node_ids = dbcon.fetchall() for index, node in enumerate(self.multinode_data['known_nodes']): if (index, ) in to_be_deleted_node_ids: to_be_deleted_node_ids.remove( (index, )) #don't remove this current node result = dbcon.execute( 'select * from ' + tbl_node_info + ' where node_id=%s', (index, )) node_type = 'parent' if self.multinode_status[ node['name']]['health'] == 2 else "child" node_status = "ONLINE" if self.multinode_status[ node['name']]['health'] > 0 else "OFFLINE" ip_address = self.extract_ip(node['address']) last_scanned_time = self.multinode_status[ node['name']]['last_online_time'] last_offline_time = self.multinode_status[ node['name']]['last_offline_time'] last_sync_time = self.multinode_status[ node['name']]['last_sync_time'] var_list = "(node_id,node_name,node_type,node_status,ip_address,last_scanned_time,last_offline_time,last_sync_time)" value_placeholder_list = "(%s,%s,%s,%s,%s,%s,%s,%s)" actual_values_list = (index, node['name'], node_type, node_status, ip_address, last_scanned_time, last_offline_time, last_sync_time) if dbcon.rowcount == 0: dbcon.execute( "insert into " + tbl_node_info + " " + var_list + " VALUES" + value_placeholder_list, actual_values_list) else: dbcon.execute( "update " + tbl_node_info + " SET " + var_list + " = " + value_placeholder_list + " where node_id = %s", actual_values_list + (index, )) dbcon.commit() for id in to_be_deleted_node_ids: dbcon.execute( 'delete from accounts_userprofile_nodes where nodeinfo_id=%s', id) #delete entries in user-profile for the old node dbcon.commit() dbcon.execute('delete from ' + tbl_node_info + ' where node_id=%s', id) #delete the old nodes dbcon.commit() if self.recently_online_node_list: #Online nodes should be handled first because, the same node can first be #on both recently_online_node_list and 
recently_offline_node_list, when it goes offline shortly after #coming online self.handle_online_nodes(dbcon, self.recently_online_node_list) self.recently_online_node_list = [] # reset after handling if self.recently_offline_node_list: self.handle_offline_nodes(dbcon, self.recently_offline_node_list) self.recently_offline_node_list = [] # reset after handling def connect_client(self, node): server_public_file = os.path.join(self.public_keys_dir, node['server_key']) server_public, _ = zmq.auth.load_certificate(server_public_file) # The client must know the server's public key to make a CURVE connection. self.client.curve_serverkey = server_public self.client.setsockopt(zmq.SUBSCRIBE, 'heartbeat/') self.client.setsockopt(zmq.SUBSCRIBE, self.node_name) self.client.connect(node['address']) def disconnect_client(self, node): self.client.disconnect(node['address']) def configureClient(self): print "Starting to receive Heart-beat" client = self.ctx.socket(zmq.SUB) # We need two certificates, one for the client and one for # the server. The client must know the server's public key # to make a CURVE connection. 
client_secret_key_filename = self.multinode_data['known_nodes'][ self.node_index]['client_secret_key'] client_secret_file = os.path.join(self.secret_keys_dir, client_secret_key_filename) client_public, client_secret = zmq.auth.load_certificate( client_secret_file) client.curve_secretkey = client_secret client.curve_publickey = client_public self.client = client for node in self.multinode_data['known_nodes']: self.connect_client(node) def pollClients(self, dbcon): if self.client.poll(1000): sender, topic, msg = dejsonify(self.client.recv()) topic_list = topic.split('/') if topic_list[0] == 'heartbeat': node_name = sender is_parent = topic_list[2] last_sync_with_parent = topic_list[3] if self.multinode_status[node_name][ 'health'] < 0: #the node health was <0 , means offline print node_name + " is back online" self.recently_online_node_list += [node_name] self.sync_node_with_parent(node_name) if is_parent.lower() in ['false', '0']: self.multinode_status[node_name]['health'] = 1 elif is_parent.lower() in ['true', '1']: self.multinode_status[node_name]['health'] = 2 self.parent_node = node_name else: raise ValueError( 'Invalid is_parent string in heart-beat message') self.multinode_status[node_name][ 'last_online_time'] = datetime.now() self.multinode_status[node_name][ 'last_sync_time'] = datetime.strptime( last_sync_with_parent, '%B %d, %Y, %H:%M:%S') if topic_list[0] == self.node_name: if topic_list[1] == 'sync-with-parent': pass # print topic # self.last_sync_with_parent = datetime.strptime(topic_list[2], '%B %d, %Y, %H:%M:%S') # content = base64.b64decode(msg['database_dump']) # newpath = 'bemossdb.sql' # with open(newpath, 'w') as f: # f.write(content) # try: # os.system( # 'psql -c "SELECT pg_terminate_backend(pg_stat_activity.pid) FROM pg_stat_activity WHERE pid <> pg_backend_pid();"') # os.system( # 'dropdb bemossdb') # This step requires all connections to be closed # os.system('createdb bemossdb -O admin') # dump_result = subprocess.check_output('psql 
bemossdb < ' + newpath, shell=True) # except Exception as er: # print "Couldn't sync database with parent because of error: " # print er # # parent_node_name = topic_list[3] # self.updateParent(parent_node_name) if topic_list[1] == 'republish': target = msg['target'] actual_message = msg['actual_message'] actual_topic = msg['actual_topic'] self.bemoss_publish(target=target, topic=actual_topic + '/republished', message=actual_message, sender=sender) print self.node_name + ": " + topic, str(msg) else: time.sleep(2) def cleanup(self): # stop auth thread self.auth.stop() def updateMultinodeData(self, dbcon, sender, topic, message): print "Updating Multinode data" topic_list = topic.split('/') self.configure_authenticator() #to/multinodeagent/from/<doesn't matter>/update_multinode_data if topic_list[4] == 'update_multinode_data': old_multinode_data = self.multinode_data self.getMultinodeData() for node in self.multinode_data['known_nodes']: if node not in old_multinode_data['known_nodes']: print "New node has been added to the cluster: " + node[ 'name'] print "We will connect to it" self.connect_client(node) for node in old_multinode_data['known_nodes']: if node not in self.multinode_data['known_nodes']: print "Node has been removed from the cluster: " + node[ 'name'] print "We will disconnect from it" self.disconnect_client(node) # TODO: remove it from the node_info table print "yay! 
got it" def relayDirectMessage(self, dbcon, sender, topic, message): print topic #to/<some_agent_or_ui>/topic/from/<some_agent_or_ui> from_entity = sender target = message['target'] actual_message = message['actual_message'] actual_topic = message['actual_topic'] for to_entity in target: if to_entity in settings.NO_FORWARD_AGENTS: return #no forwarding should be done for these agents elif to_entity in settings.PARENT_NODE_SYSTEM_AGENTS: if not self.is_parent: self.republishToParent(sender, topic, message) elif to_entity == "ALL": self.disperseMessage(sender, topic=topic, message=message) else: dbcon.execute( "SELECT current_node_id FROM " + node_devices_table + " WHERE agent_id=%s", (to_entity, )) if dbcon.rowcount: node_id = dbcon.fetchone()[0] if node_id != self.node_index: self.server.send( jsonify( sender, self.getNodeName(node_id) + '/republish/' + topic, message)) else: self.disperseMessage( sender, topic, message ) #republish to all nodes if we don't know where to send
class Driver(drivers.BaseDriver):
    """ZeroMQ implementation of the directord driver.

    In server mode the job and backend channels are ROUTER sockets bound to
    ``connection_string``; in any other mode they are DEALER sockets
    connected to it. Shared-key (PLAIN) or CURVE authentication is enabled
    through ``args.zmq_shared_key`` / ``args.zmq_curve_encryption``.
    """

    def __init__(
        self,
        args,
        encrypted_traffic_data=None,
        interface=None,
    ):
        """Initialize the Driver.

        :param args: Arguments parsed by argparse.
        :type args: Object
        :param encrypted_traffic_data: Encrypted traffic settings; the keys
                                       read here are ``enabled``,
                                       ``secret_keys_dir`` and
                                       ``public_keys_dir``.
        :type encrypted_traffic_data: Dictionary
        :param interface: The interface instance (client/server)
        :type interface: Object
        """
        self.thread_processor = multiprocessing.Process
        self.event = multiprocessing.Event()
        self.semaphore = multiprocessing.Semaphore
        self.flushqueue = _FlushQueue
        self.args = args
        if getattr(self.args, "zmq_generate_keys", False) is True:
            self._generate_certificates()
            print("New certificates generated")
            raise SystemExit(0)

        self.encrypted_traffic_data = encrypted_traffic_data

        mode = getattr(self.args, "mode", None)
        if mode == "client":
            self.bind_address = self.args.zmq_server_address
        elif mode == "server":
            self.bind_address = self.args.zmq_bind_address
        else:
            self.bind_address = "*"
        self.proto = "tcp"
        self.connection_string = "{proto}://{addr}".format(
            proto=self.proto, addr=self.bind_address)

        if self.encrypted_traffic_data:
            self.encrypted_traffic = self.encrypted_traffic_data.get("enabled")
            self.secret_keys_dir = self.encrypted_traffic_data.get(
                "secret_keys_dir")
            self.public_keys_dir = self.encrypted_traffic_data.get(
                "public_keys_dir")
        else:
            self.encrypted_traffic = False
            self.secret_keys_dir = None
            self.public_keys_dir = None

        self._context = zmq.Context()
        # NOTE: pyzmq's ``Context.instance()`` is a classmethod, so this is
        # the process-wide singleton context, not ``self._context``.
        self.ctx = self._context.instance()
        self.poller = zmq.Poller()
        self.interface = interface
        super(Driver, self).__init__(
            args=args,
            encrypted_traffic_data=self.encrypted_traffic_data,
            interface=interface,
        )
        self.bind_job = None
        self.bind_backend = None
        self.hwm = getattr(self.args, "zmq_highwater_mark", 1024)

    def __copy__(self):
        """Return a new copy of the driver."""
        return Driver(
            args=self.args,
            encrypted_traffic_data=self.encrypted_traffic_data,
            interface=self.interface,
        )

    def _backend_bind(self):
        """Bind an address to a backend socket and return the socket.

        :returns: Object
        """
        bind = self._socket_bind(
            socket_type=zmq.ROUTER,
            connection=self.connection_string,
            port=self.args.backend_port,
        )
        bind.set_hwm(self.hwm)
        self.log.debug(
            "Identity [ %s ] backend bind hwm state [ %s ]",
            self.identity,
            bind.get_hwm(),
        )
        return bind

    def _backend_connect(self):
        """Connect to a backend socket and return the socket.

        :returns: Object
        """
        self.log.debug("Establishing backend connection.")
        bind = self._socket_connect(
            socket_type=zmq.DEALER,
            connection=self.connection_string,
            port=self.args.backend_port,
        )
        bind.set_hwm(self.hwm)
        self.log.debug(
            "Identity [ %s ] backend connect hwm state [ %s ]",
            self.identity,
            bind.get_hwm(),
        )
        return bind

    def _bind_check(self, bind, interval=1, constant=1000):
        """Return True if a bind type contains work ready.

        :param bind: A given Socket bind to identify.
        :type bind: Object
        :param interval: Exponential Interval used to determine the polling
                         duration for a given socket.
        :type interval: Integer
        :param constant: Constant time used to poll for new jobs.
        :type constant: Integer
        :returns: Boolean
        """
        socks = dict(self.poller.poll(interval * constant))
        # Test the POLLIN bit rather than comparing for equality so a socket
        # registered for several event types is still reported as readable.
        return bool(socks.get(bind, 0) & zmq.POLLIN)

    def _close(self, socket):
        """Close ``socket``, retrying for up to 60 seconds.

        Errors are logged rather than raised; ``None`` is ignored.

        :param socket: ZeroMQ socket object.
        :type socket: Object
        """
        if socket is None:
            return

        try:
            socket.close(linger=2)
            close_time = time.time()
            while not socket.closed:
                if time.time() - close_time > 60:
                    raise TimeoutError(
                        "Job [ {} ] failed to close transfer socket".format(
                            getattr(self, "job_id", None)))
                else:
                    socket.close(linger=2)
                    time.sleep(1)
        except Exception as e:
            self.log.error(
                "Ran into an exception while closing the socket %s",
                str(e),
            )
        else:
            self.log.debug("Backend socket closed")

    def _generate_certificates(self, base_dir="/etc/directord"):
        """Generate client and server CURVE certificate files.

        Existing public/secret keys are renamed to ``*.bak`` first, then new
        ``server`` and ``client`` key pairs are created and moved into the
        ``public_keys`` / ``private_keys`` directories.

        :param base_dir: Directord configuration path.
        :type base_dir: String
        """
        keys_dir = os.path.join(base_dir, "certificates")
        public_keys_dir = os.path.join(base_dir, "public_keys")
        secret_keys_dir = os.path.join(base_dir, "private_keys")

        for item in [keys_dir, public_keys_dir, secret_keys_dir]:
            os.makedirs(item, exist_ok=True)

        # Run certificate backup
        self._move_certificates(directory=public_keys_dir, backup=True)
        self._move_certificates(directory=secret_keys_dir,
                                backup=True,
                                suffix=".key_secret")

        # create new keys in certificates dir
        for item in ["server", "client"]:
            self._key_generate(keys_dir=keys_dir, key_type=item)

        # Move generated certificates in place
        self._move_certificates(
            directory=keys_dir,
            target_directory=public_keys_dir,
            suffix=".key",
        )
        self._move_certificates(
            directory=keys_dir,
            target_directory=secret_keys_dir,
            suffix=".key_secret",
        )

    def _job_bind(self):
        """Bind an address to a job socket and return the socket.

        :returns: Object
        """
        return self._socket_bind(
            socket_type=zmq.ROUTER,
            connection=self.connection_string,
            port=self.args.job_port,
        )

    def _job_connect(self):
        """Connect to a job socket and return the socket.

        :returns: Object
        """
        self.log.debug("Establishing Job connection.")
        return self._socket_connect(
            socket_type=zmq.DEALER,
            connection=self.connection_string,
            port=self.args.job_port,
        )

    def _key_generate(self, keys_dir, key_type):
        """Generate certificate.

        :param keys_dir: Full Directory path where a given key will be stored.
        :type keys_dir: String
        :param key_type: Key type to be generated.
        :type key_type: String
        """
        zmq_auth.create_certificates(keys_dir, key_type)

    @staticmethod
    def _move_certificates(directory,
                           target_directory=None,
                           backup=False,
                           suffix=".key"):
        """Move certificates when required.

        Every entry of ``directory`` whose name ends with ``suffix`` is renamed
        into ``target_directory`` (or left in ``directory``), with ``.bak``
        appended to the name when ``backup`` is set.

        :param directory: Set the origin path.
        :type directory: String
        :param target_directory: Set the target path.
        :type target_directory: String
        :param backup: Enable file backup before moving.
        :type backup: Boolean
        :param suffix: Set the search suffix
        :type suffix: String
        """
        destination = target_directory or directory
        for item in os.listdir(directory):
            if not item.endswith(suffix):
                continue
            target_file = "{}.bak".format(item) if backup else item
            os.rename(
                os.path.join(directory, item),
                os.path.join(destination, target_file),
            )

    def _start_authenticator(self):
        """Return a running ThreadAuthenticator bound to ``self.ctx``.

        The ZAP handler binds the fixed endpoint ``inproc://zeromq.zap.01``,
        so only one can run per context. A live authenticator already owned
        by this driver is reused instead of starting a new one for every
        bound socket; a dead one (e.g. inherited across a fork) is replaced.

        :returns: Object
        """
        auth = getattr(self, "auth", None)
        if auth is None or not auth.is_alive():
            auth = self.auth = ThreadAuthenticator(self.ctx, log=self.log)
            auth.start()
            auth.allow()
        return auth

    def _socket_bind(self, socket_type, connection, port, poller_type=None):
        """Return a socket object which has been bound to a given address.

        When the socket_type is not PUB, the bound socket will also be
        registered with self.poller as defined within the Interface class.

        :param socket_type: Set the Socket type, typically defined using a
                            ZeroMQ constant.
        :type socket_type: Integer
        :param connection: Set the Address information used for the bound
                           socket.
        :type connection: String
        :param port: Define the port which the socket will be bound to.
        :type port: Integer
        :param poller_type: Set the Socket type, typically defined using a
                            ZeroMQ constant.
        :type poller_type: Integer
        :returns: Object
        """
        if poller_type is None:
            poller_type = zmq.POLLIN

        bind = self._socket_context(socket_type=socket_type)
        auth_enabled = (self.args.zmq_shared_key
                        or self.args.zmq_curve_encryption)

        if auth_enabled:
            auth = self._start_authenticator()
            if self.args.zmq_shared_key:
                # Enables basic auth
                auth.configure_plain(
                    domain="*", passwords={"admin": self.args.zmq_shared_key})
                bind.plain_server = True  # Enable shared key authentication
                self.log.info("Shared key authentication enabled.")
            elif self.args.zmq_curve_encryption:
                server_secret_file = os.path.join(self.secret_keys_dir,
                                                  "server.key_secret")
                for item in [
                        self.public_keys_dir,
                        self.secret_keys_dir,
                        server_secret_file,
                ]:
                    if not os.path.exists(item):
                        raise SystemExit(
                            "The required path [ {} ] does not exist. Have"
                            " you generated your keys?".format(item))
                auth.configure_curve(domain="*",
                                     location=self.public_keys_dir)
                try:
                    server_public, server_secret = zmq_auth.load_certificate(
                        server_secret_file)
                except OSError as e:
                    self.log.error(
                        "Failed to load certificates: %s, Configuration: %s",
                        str(e),
                        vars(self.args),
                    )
                    raise SystemExit("Failed to load certificates")
                else:
                    bind.curve_secretkey = server_secret
                    bind.curve_publickey = server_public
                    bind.curve_server = True  # Enable curve authentication

        bind.bind("{connection}:{port}".format(
            connection=connection,
            port=port,
        ))

        if socket_type not in [zmq.PUB]:
            self.poller.register(bind, poller_type)

        return bind

    def _socket_connect(self, socket_type, connection, port, poller_type=None):
        """Return a socket object which has been connected to a given address.

        The socket is registered with self.poller before connecting.

        :param socket_type: Set the Socket type, typically defined using a
                            ZeroMQ constant.
        :type socket_type: Integer
        :param connection: Set the Address information used for the
                           connected socket.
        :type connection: String
        :param port: Define the port which the socket will connect to.
        :type port: Integer
        :param poller_type: Set the Socket type, typically defined using a
                            ZeroMQ constant.
        :type poller_type: Integer
        :returns: Object
        """
        if poller_type is None:
            poller_type = zmq.POLLIN

        bind = self._socket_context(socket_type=socket_type)

        if self.args.zmq_shared_key:
            bind.plain_username = b"admin"  # User is hard coded.
            bind.plain_password = self.args.zmq_shared_key.encode()
            self.log.info("Shared key authentication enabled.")
        elif self.args.zmq_curve_encryption:
            client_secret_file = os.path.join(self.secret_keys_dir,
                                              "client.key_secret")
            server_public_file = os.path.join(self.public_keys_dir,
                                              "server.key")
            for item in [
                    self.public_keys_dir,
                    self.secret_keys_dir,
                    client_secret_file,
                    server_public_file,
            ]:
                if not os.path.exists(item):
                    raise SystemExit(
                        "The required path [ {} ] does not exist. Have"
                        " you generated your keys?".format(item))
            try:
                client_public, client_secret = zmq_auth.load_certificate(
                    client_secret_file)
                server_public, _ = zmq_auth.load_certificate(
                    server_public_file)
            except OSError as e:
                self.log.error(
                    "Error while loading certificates: %s. Configuration: %s",
                    str(e),
                    vars(self.args),
                )
                raise SystemExit("Failed to load keys.")
            else:
                bind.curve_secretkey = client_secret
                bind.curve_publickey = client_public
                bind.curve_serverkey = server_public

        if socket_type == zmq.SUB:
            bind.setsockopt_string(zmq.SUBSCRIBE, self.identity)
        else:
            bind.setsockopt_string(zmq.IDENTITY, self.identity)

        self.poller.register(bind, poller_type)
        bind.connect("{connection}:{port}".format(
            connection=connection,
            port=port,
        ))

        self.log.info("Socket connected to [ %s ].", connection)
        return bind

    def _socket_context(self, socket_type):
        """Create socket context and return a bind object.

        :param socket_type: Set the Socket type, typically defined using a
                            ZeroMQ constant.
        :type socket_type: Integer
        :returns: Object
        """
        bind = self.ctx.socket(socket_type)
        # NOTE(review): ZMQ_LINGER is in milliseconds; if heartbeat_interval
        # is meant in seconds this linger is much shorter than intended.
        bind.linger = getattr(self.args, "heartbeat_interval", 60)
        hwm = int(self.hwm * 4)
        # Same high-water mark in both directions.
        bind.setsockopt(zmq.SNDHWM, hwm)
        bind.setsockopt(zmq.RCVHWM, hwm)
        if socket_type == zmq.ROUTER:
            # Make unroutable sends raise instead of silently dropping.
            bind.setsockopt(zmq.ROUTER_MANDATORY, 1)
        return bind

    @staticmethod
    def _socket_recv(socket, nonblocking=False):
        """Receive a message over a ZM0 socket.

        The message specification for server is as follows.

            [
                b"Identity"
                b"ID",
                b"ASCII Control Characters",
                b"command",
                b"data",
                b"info",
                b"stderr",
                b"stdout",
            ]

        The message specification for client is as follows.

            [
                b"ID",
                b"ASCII Control Characters",
                b"command",
                b"data",
                b"info",
                b"stderr",
                b"stdout",
            ]

        All message parts are byte encoded.

        All possible control characters are defined within the Interface class.
        For more on control characters review the following
        URL(https://donsnotes.com/tech/charsets/ascii.html#cntrl).

        :param socket: ZeroMQ socket object.
        :type socket: Object
        :param nonblocking: Enable non-blocking receive.
        :type nonblocking: Boolean
        """
        if nonblocking:
            flags = zmq.NOBLOCK
        else:
            flags = 0

        return socket.recv_multipart(flags=flags)

    @tenacity.retry(
        retry=tenacity.retry_if_exception_type(Exception),
        wait=tenacity.wait_fixed(5),
        before_sleep=tenacity.before_sleep_log(
            logger.getLogger(name="directord"), logging.WARN),
    )
    def _socket_send(
        self,
        socket,
        identity=None,
        msg_id=None,
        control=None,
        command=None,
        data=None,
        info=None,
        stderr=None,
        stdout=None,
        nonblocking=False,
    ):
        """Send a message over a ZM0 socket.

        No stop condition is given to the retry decorator, so a failing send
        is retried every 5 seconds until it succeeds.

        The message specification for server is as follows.

            [
                b"Identity"
                b"ID",
                b"ASCII Control Characters",
                b"command",
                b"data",
                b"info",
                b"stderr",
                b"stdout",
            ]

        The message specification for client is as follows.

            [
                b"ID",
                b"ASCII Control Characters",
                b"command",
                b"data",
                b"info",
                b"stderr",
                b"stdout",
            ]

        All message information is assumed to be byte encoded.

        All possible control characters are defined within the Interface class.
        For more on control characters review the following
        URL(https://donsnotes.com/tech/charsets/ascii.html#cntrl).

        :param socket: ZeroMQ socket object.
        :type socket: Object
        :param identity: Target where message will be sent.
        :type identity: Bytes
        :param msg_id: ID information for a given message. If no ID is
                       provided a UUID will be generated.
        :type msg_id: Bytes
        :param control: ASCII control characters.
        :type control: Bytes
        :param command: Command definition for a given message.
        :type command: Bytes
        :param data: Encoded data that will be transmitted.
        :type data: Bytes
        :param info: Encoded information that will be transmitted.
        :type info: Bytes
        :param stderr: Encoded error information from a command.
        :type stderr: Bytes
        :param stdout: Encoded output information from a command.
        :type stdout: Bytes
        :param nonblocking: Enable non-blocking send.
        :type nonblocking: Boolean
        :returns: Object
        """

        def _encoder(item):
            # str parts are encoded; bytes (no .encode) pass through as-is.
            try:
                return item.encode()
            except AttributeError:
                return item

        if not msg_id:
            msg_id = utils.get_uuid()
        if not control:
            control = self.nullbyte
        if not command:
            command = self.nullbyte
        if not data:
            data = self.nullbyte
        if not info:
            info = self.nullbyte
        if not stderr:
            stderr = self.nullbyte
        if not stdout:
            stdout = self.nullbyte

        message_parts = [msg_id, control, command, data, info, stderr, stdout]

        if identity:
            message_parts.insert(0, identity)

        message_parts = [_encoder(i) for i in message_parts]

        if nonblocking:
            flags = zmq.NOBLOCK
        else:
            flags = 0

        try:
            return socket.send_multipart(message_parts, flags=flags)
        except Exception:
            self.log.warning("Failed to send message to [ %s ]", identity)
            raise

    def _recv(self, socket, nonblocking=False):
        """Receive message.

        :param socket: ZeroMQ socket object.
        :type socket: Object
        :param nonblocking: Enable non-blocking receive.
        :type nonblocking: Boolean
        :returns: Tuple of decoded message parts.
        """
        recv_obj = self._socket_recv(socket=socket, nonblocking=nonblocking)
        return tuple([i.decode() for i in recv_obj])

    def backend_recv(self, nonblocking=False):
        """Receive a transfer message.

        :param nonblocking: Enable non-blocking receive.
        :type nonblocking: Boolean
        :returns: Tuple
        """
        return self._recv(socket=self.bind_backend, nonblocking=nonblocking)

    def backend_init(self):
        """Initialize the backend socket.

        For server mode, this is a bound local socket.
        For client mode, it is a connection to the server socket.

        :returns: Object
        """
        if self.args.mode == "server":
            self.bind_backend = self._backend_bind()
        else:
            self.bind_backend = self._backend_connect()

    def backend_close(self):
        """Close the backend socket."""
        self._close(socket=self.bind_backend)

    def backend_check(self, interval=1, constant=1000):
        """Return True if the backend contains work ready.

        :param interval: Exponential Interval used to determine the polling
                         duration for a given socket.
        :type interval: Integer
        :param constant: Constant time used to poll for new jobs.
        :type constant: Integer
        :returns: Boolean
        """
        return self._bind_check(bind=self.bind_backend,
                                interval=interval,
                                constant=constant)

    def backend_send(self, *args, **kwargs):
        """Send a job message.

        * All args and kwargs are passed through to the socket send.

        :returns: Object
        """
        kwargs["socket"] = self.bind_backend
        return self._socket_send(*args, **kwargs)

    @staticmethod
    def get_lock():
        """Returns a thread lock."""
        return multiprocessing.Lock()

    def heartbeat_send(self,
                       host_uptime=None,
                       agent_uptime=None,
                       version=None,
                       driver=None):
        """Send a heartbeat.

        :param host_uptime: Sender uptime
        :type host_uptime: String
        :param agent_uptime: Sender agent uptime
        :type agent_uptime: String
        :param version: Sender directord version
        :type version: String
        :param driver: Driver information
        :type driver: String
        """
        job_id = utils.get_uuid()
        self.log.info(
            "Job [ %s ] sending heartbeat from [ %s ] to server",
            job_id,
            self.identity,
        )

        return self.job_send(
            control=self.heartbeat_notice,
            msg_id=job_id,
            data=json.dumps({
                "job_id": job_id,
                "version": version,
                "host_uptime": host_uptime,
                "agent_uptime": agent_uptime,
                "machine_id": self.machine_id,
                "driver": driver,
            }),
        )

    def job_send(self, *args, **kwargs):
        """Send a job message.

        * All args and kwargs are passed through to the socket send.

        :returns: Object
        """
        kwargs["socket"] = self.bind_job
        return self._socket_send(*args, **kwargs)

    def job_recv(self, nonblocking=False):
        """Receive a transfer message.

        :param nonblocking: Enable non-blocking receive.
        :type nonblocking: Boolean
        :returns: Tuple
        """
        return self._recv(socket=self.bind_job, nonblocking=nonblocking)

    def job_init(self):
        """Initialize the job socket.

        For server mode, this is a bound local socket.
        For client mode, it is a connection to the server socket.

        :returns: Object
        """
        if self.args.mode == "server":
            self.bind_job = self._job_bind()
        else:
            self.bind_job = self._job_connect()

    def job_close(self):
        """Close the job socket."""
        self._close(socket=self.bind_job)

    def job_check(self, interval=1, constant=1000):
        """Return True if a job contains work ready.

        :param interval: Exponential Interval used to determine the polling
                         duration for a given socket.
        :type interval: Integer
        :param constant: Constant time used to poll for new jobs.
        :type constant: Integer
        :returns: Boolean
        """
        return self._bind_check(bind=self.bind_job,
                                interval=interval,
                                constant=constant)

    def shutdown(self):
        """Shutdown the driver.

        Sockets are closed first, then the authenticator thread (if one is
        running) is stopped, and only then are the contexts closed, since
        terminating a context waits on any sockets still open in it.
        """
        self.job_close()
        self.backend_close()
        auth = getattr(self, "auth", None)
        if auth is not None and auth.is_alive():
            auth.stop()
        if hasattr(self.ctx, "close"):
            self.ctx.close()
        if hasattr(self._context, "close"):
            self._context.close()
class StratusApp(StratusServerApp):
    """ZeroMQ request/reply front end for a ``StratusCore``.

    Requests arrive on a REP socket as ``"<submissionId>!<type>[!<json>]"``
    strings; replies are sent on the same socket as ``"<id>!<message>"``.
    """

    def __init__(self, core: StratusCore, **kwargs):
        StratusServerApp.__init__(self, core, **kwargs)
        self.logger = StratusLogger.getLogger()
        self.active = True
        self.parms = self.getConfigParms('stratus')
        self.client_address = self.parms.get("client_address", "*")
        self.request_port = self.parms.get("request_port", 4556)
        self.response_port = self.parms.get("response_port", 4557)
        self.active_handlers = {}
        self.getCertDirs()

    def getCertDirs(self):
        """Resolve the certificate directories, generating them if missing.

        These directories are generated by the generate_certificates script.
        """
        self.cert_dir = self.parms.get("certificate_path",
                                       os.path.expanduser("~/.stratus/zmq"))
        self.logger.info(
            f"Loading certificates and keys from directory {self.cert_dir}")
        self.keys_dir = os.path.join(self.cert_dir, 'certificates')
        self.public_keys_dir = os.path.join(self.cert_dir, 'public_keys')
        self.secret_keys_dir = os.path.join(self.cert_dir, 'private_keys')

        if not (os.path.exists(self.keys_dir)
                and os.path.exists(self.public_keys_dir)
                and os.path.exists(self.secret_keys_dir)):
            from stratus.handlers.zeromq.security.generate_certificates import generate_certificates
            generate_certificates(self.cert_dir)

    def initSocket(self):
        """Bind the request socket; errors are logged, not raised."""
        try:
            server_secret_file = os.path.join(self.secret_keys_dir,
                                              "server.key_secret")
            server_public, server_secret = zmq.auth.load_certificate(
                server_secret_file)
            # TODO: this is commented to avoid key checking
            #self.request_socket.curve_secretkey = server_secret
            #self.request_socket.curve_publickey = server_public
            #self.request_socket.curve_server = True
            self.request_socket.bind("tcp://{}:{}".format(
                self.client_address, self.request_port))
            self.logger.info(
                "@@STRATUS-APP --> Bound authenticated request socket to client at {} on port: {}"
                .format(self.client_address, self.request_port))
        except Exception as err:
            self.logger.error(
                "@@STRATUS-APP: Error initializing request socket on {}, port {}: {}"
                .format(self.client_address, self.request_port, err))
            self.logger.error(traceback.format_exc())

    def addHandler(self, clientId, jobId, handler):
        """Register ``handler`` under ``"<clientId>-<jobId>"`` and return it."""
        self.active_handlers[clientId + "-" + jobId] = handler
        return handler

    def removeHandler(self, clientId, jobId):
        """Unregister a handler; a missing handler is logged, not raised."""
        handlerId = clientId + "-" + jobId
        try:
            del self.active_handlers[handlerId]
        except KeyError:
            self.logger.error("Error removing handler: " + handlerId +
                              ", active handlers = " +
                              str(list(self.active_handlers.keys())))

    def setExeStatus(self, submissionId: str, status: Status):
        self.responder.setExeStatus(submissionId, status)

    def sendResponseMessage(self, msg: StratusResponse) -> str:
        """Send ``"<id>!<message>"`` on the request socket and return it."""
        request_args = [msg.id, msg.message]
        packaged_msg = "!".join(request_args)
        # strftime needs %-directives; the previous "MM/dd HH:mm:ss" pattern
        # was emitted literally instead of producing a timestamp.
        timeStamp = datetime.datetime.now().strftime("%m/%d %H:%M:%S")
        self.logger.info(
            "@@STRATUS-APP: Sending response {} on request_socket @({}): {}".
            format(msg.id, timeStamp, str(msg)))
        self.request_socket.send_string(packaged_msg)
        return packaged_msg

    def initInteractions(self):
        """Create the context, authenticator, request socket and responder."""
        try:
            self.zmqContext: zmq.Context = zmq.Context()

            self.auth = ThreadAuthenticator(self.zmqContext)
            self.auth.start()
            # NOTE(review): hard-coded address looks like a development
            # leftover; consider moving it into the 'stratus' config parms.
            self.auth.allow("192.168.0.22")
            self.auth.allow(self.client_address)
            self.auth.configure_curve(
                domain='*', location=zmq.auth.CURVE_ALLOW_ANY
            )  # self.public_keys_dir )  # Use 'location=zmq.auth.CURVE_ALLOW_ANY' for stonehouse security

            self.request_socket: zmq.Socket = self.zmqContext.socket(zmq.REP)
            self.responder = StratusZMQResponder(
                self.zmqContext,
                self.response_port,
                client_address=self.client_address,
                certificate_path=self.cert_dir)
            self.initSocket()
            self.logger.info(
                "@@STRATUS-APP:Listening for requests on port: {}".format(
                    self.request_port))

        except Exception as err:
            self.logger.error(
                "@@STRATUS-APP:  ------------------------------- StratusApp Init error: {} ------------------------------- "
                .format(err))

    def processResults(self):
        """Let the responder process workflows and clear completed ones."""
        completed_workflows = self.responder.processWorkflows(
            self.getWorkflows())
        for rid in completed_workflows:
            self.clearWorkflow(rid)

    def processRequests(self):
        """Handle every request currently queued on ``request_socket``.

        Header parsing happens inside the ``try`` so a malformed request is
        answered with an error reply; a REP socket that received a request
        must send a reply before it can receive the next one.
        """
        while self.request_socket.poll(0) != 0:
            request_header = self.request_socket.recv_string().strip().strip(
                "'")
            parts = request_header.split("!")
            submissionId = str(parts[0])
            try:
                if len(parts) < 2:
                    raise Exception("Malformed request header: " +
                                    request_header)
                rType = str(parts[1])
                request: Dict = json.loads(
                    parts[2]) if len(parts) > 2 else ""
                self.logger.info(
                    "@@STRATUS-APP:  ### Processing {} request: {}".format(
                        rType, request))
                if rType == "capabilities":
                    response = self.core.getCapabilities(request["type"])
                    self.sendResponseMessage(
                        StratusResponse(submissionId, response))
                elif rType == "exe":
                    if len(parts) <= 2:
                        raise Exception("Missing parameters to exe request")
                    request["rid"] = submissionId
                    self.logger.info(
                        "Processing zmq Request: '{}' '{}' '{}'".format(
                            submissionId, rType, str(request)))
                    self.submitWorkflow(
                        request)  # TODO: Send results when tasks complete.
                    response = {"status": "Executing"}
                    self.sendResponseMessage(
                        StratusResponse(submissionId, response))
                elif rType == "quit" or rType == "shutdown":
                    response = {"status": "Terminating"}
                    self.sendResponseMessage(
                        StratusResponse(submissionId, response))
                    self.logger.info(
                        "@@STRATUS-APP: Received Shutdown Message")
                    # SystemExit is not an Exception subclass, so it is not
                    # swallowed by the handler below.
                    raise SystemExit(0)
                else:
                    msg = "@@STRATUS-APP: Unknown request type: " + rType
                    self.logger.info(msg)
                    response = {"status": "error", "error": msg}
                    self.sendResponseMessage(
                        StratusResponse(submissionId, response))
            except Exception as ex:
                self.processError(submissionId, ex)

    def processError(self, rid: str, ex: Exception):
        """Log ``ex`` with its traceback and send an error reply for ``rid``."""
        tb = traceback.format_exc()
        self.logger.error("@@STRATUS-APP: Execution error: " + str(ex))
        self.logger.error(tb)
        response = {"status": "error", "error": str(ex), "traceback": tb}
        self.sendResponseMessage(StratusResponse(rid, response))

    def updateInteractions(self):
        self.processRequests()
        self.processResults()

    def term(self, msg):
        """Stop the authenticator, close sockets/responder and shut down."""
        self.logger.info("@@STRATUS-APP: !!EDAS Shutdown: " + msg)
        self.active = False
        # initInteractions may have failed before the authenticator existed.
        auth = getattr(self, "auth", None)
        if auth is not None:
            auth.stop()
        self.logger.info("@@STRATUS-APP: QUIT PythonWorkerPortal")
        try:
            self.request_socket.close()
        except Exception:
            pass
        self.logger.info("@@STRATUS-APP: CLOSE request_socket")
        self.responder.close_connection()
        self.logger.info("@@STRATUS-APP: TERM responder")
        self.shutdown()
        self.logger.info("@@STRATUS-APP: shutdown complete")
def _run(self, _, frontend, sink, *backend_socks):
    """Bind sockets, start helper processes and route client requests.

    ``frontend`` receives 4-frame requests ``[client, msg, req_id, msg_len]``;
    ``sink`` talks to the BertSink process; ``backend_socks`` feed the
    BertWorker processes, with ``backend_socks[0]`` reserved for jobs whose
    size is at most ``self.args.priority_batch_size``. The loop runs until a
    ``ServerCmd.terminate`` request arrives, then closes every child process.

    :raises Exception: if ``self.certs_path`` is set but the certificate
        directories or server keys are missing.
    """

    def push_new_job(_job_id, _json_msg, _msg_len):
        # backend_socks[0] is always at the highest priority
        _sock = (backend_socks[0]
                 if _msg_len <= self.args.priority_batch_size else
                 rand_backend_socket)
        _sock.send_multipart([_job_id, _json_msg])

    # bind all sockets
    self.logger.info('bind all sockets')
    auth = None
    if self.certs_path is not None:
        base_dir = self.certs_path
        public_keys_dir = os.path.join(base_dir, 'public_keys')
        secret_keys_dir = os.path.join(base_dir, 'private_keys')
        server_secret_file = os.path.join(secret_keys_dir,
                                          "server.key_secret")
        if not (os.path.exists(base_dir) and os.path.exists(public_keys_dir)
                and os.path.exists(secret_keys_dir)):
            self.logger.critical(
                "No certificates dirs found in %s directory" % base_dir)
            raise Exception("No certificates dirs found in %s directory" %
                            base_dir)
        server_public, server_secret = zmq.auth.load_certificate(
            server_secret_file)
        frontend.curve_secretkey = server_secret
        frontend.curve_publickey = server_public
        frontend.curve_server = True
        self.args.server_public = server_public
        self.args.server_secret = server_secret
        if self.args.server_secret is None or self.args.server_public is None:
            self.logger.critical(
                "No certificates found in %s and %s directories" %
                (public_keys_dir, secret_keys_dir))
            raise Exception(
                "No certificates found in %s and %s directories" %
                (public_keys_dir, secret_keys_dir))
        if self.allowed_public_keys_dir is not None:
            auth = ThreadAuthenticator(frontend.context)
            auth.start()
            # Tell authenticator to use the certificate in a directory
            auth.configure_curve(domain='*',
                                 location=self.allowed_public_keys_dir)

    frontend.bind('tcp://*:%d' % self.port)
    addr_front2sink = auto_bind(sink)
    addr_backend_list = [auto_bind(b) for b in backend_socks]
    self.logger.info('open %d ventilator-worker sockets' %
                     len(addr_backend_list))

    # start the sink process
    self.logger.info('start the sink')
    proc_sink = BertSink(self.args, addr_front2sink, self.bert_config)
    self.processes.append(proc_sink)
    proc_sink.start()
    addr_sink = sink.recv().decode('ascii')

    # start the backend processes
    device_map = self._get_device_map()
    for idx, device_id in enumerate(device_map):
        process = BertWorker(idx, self.args, addr_backend_list, addr_sink,
                             device_id, self.graph_path, self.bert_config)
        self.processes.append(process)
        process.start()
    self.logger.info('all set, ready to serve request! %s' % device_map)

    # start the http-service process
    if self.args.http_port:
        self.logger.info('start http proxy')
        proc_proxy = BertHTTPProxy(self.args)
        self.processes.append(proc_proxy)
        proc_proxy.start()

    rand_backend_socket = None
    server_status = ServerStatistic()

    for p in self.processes:
        p.is_ready.wait()

    self.is_ready.set()
    self.logger.info('all set, ready to serve request!')

    while True:
        try:
            request = frontend.recv_multipart()
            client, msg, req_id, msg_len = request
            # Explicit check instead of `assert`, which is stripped under -O.
            if not (req_id.isdigit() and msg_len.isdigit()):
                raise ValueError('req_id and msg_len must be digit strings')
        except ValueError:
            self.logger.error(
                'received a wrongly-formatted request (expected 4 frames, got %d)'
                % len(request))
            self.logger.error('\n'.join('field %d: %s' % (idx, k)
                                        for idx, k in enumerate(request)),
                              exc_info=True)
        else:
            server_status.update(request)
            if msg == ServerCmd.terminate:
                break
            elif msg == ServerCmd.show_config:
                self.logger.info(
                    'new config request\treq id: %d\tclient: %s' %
                    (int(req_id), client))
                status_runtime = {
                    'client': client.decode('ascii'),
                    'num_process': len(self.processes),
                    'ventilator -> worker': addr_backend_list,
                    'worker -> sink': addr_sink,
                    'ventilator <-> sink': addr_front2sink,
                    'server_current_time': str(datetime.now()),
                    'statistic': server_status.value,
                    'device_map': device_map,
                    'num_concurrent_socket': self.num_concurrent_socket
                }

                sink.send_multipart([
                    client, msg,
                    jsonapi.dumps({
                        **status_runtime,
                        **self.status_args,
                        **self.status_static
                    }), req_id
                ])
            else:
                self.logger.info(
                    'new encode request\treq id: %d\tsize: %d\tclient: %s' %
                    (int(req_id), int(msg_len), client))
                # register a new job at sink
                sink.send_multipart(
                    [client, ServerCmd.new_job, msg_len, req_id])

                # renew the backend socket to prevent large job queueing up
                # [0] is reserved for high priority job
                # last used backend shouldn't be selected either as it may
                # be queued up already. Fall back to any non-priority socket
                # (or socket 0) when excluding the last one leaves nothing.
                candidates = [
                    b for b in backend_socks[1:] if b != rand_backend_socket
                ]
                rand_backend_socket = random.choice(candidates
                                                    or backend_socks[1:]
                                                    or backend_socks)

                # push a new job, note super large job will be pushed to one
                # socket only, leaving other sockets free
                job_id = client + b'#' + req_id
                if int(msg_len) > self.max_batch_size:
                    seqs = jsonapi.loads(msg)
                    job_gen = ((job_id + b'@%d' % i,
                                seqs[i:(i + self.max_batch_size)])
                               for i in range(0, int(msg_len),
                                              self.max_batch_size))
                    for partial_job_id, job in job_gen:
                        push_new_job(partial_job_id, jsonapi.dumps(job),
                                     len(job))
                else:
                    push_new_job(job_id, msg, int(msg_len))

    for p in self.processes:
        p.close()
    # Stop the ZAP handler thread started above, if any.
    if auth is not None:
        auth.stop()
    self.logger.info('terminated!')
class FrankFancyStreamingInterface(object):
    """
    Abstraction layer to the graph streamer as well as the central logger.

    Uses a direct (non-encrypted) TCP socket connection to the streaming
    server (the "Active" live visualizer) and an (encrypted) zeromq PUB
    socket to the logger.
    """

    # Maps the internal cell status passed to ChangeCell() onto the
    # SubjectId published to the logger.
    ConvertStatus = {
        "Cells": {
            0: 5,  # removing
            1: 4,  # allocating
            2: 6,  # blacklisting
        }
    }

    # TODO: give every scheduler an unique topic to easily distinguish between them on the queue
    def __init__(self, name, privatekey, VisualizerHost, root_id, ZeromqHost="*", empty=False):
        """
        Calls internal methods to open the connections to both the Active Live visualizer and the logger.

        :param name: scheduler name, used as topic of the messages published on the queue
        :type name: str
        :param privatekey: passed to _connectLogger(); the logger connection is skipped when None
        :param VisualizerHost: The ip of the FrankFancyGraphStreamer; the visualizer connection is skipped when None
        :type VisualizerHost: str
        :param root_id: the root of the network: LBR
        :type root_id: str
        :param ZeromqHost: which interface the zeromq service needs to bind to ("*" for all interfaces)
        :type ZeromqHost: str
        :param empty: when True, no connection is opened at all
        :type empty: bool
        :return:
        """
        self.Active = None
        self.Logger = None
        self.EventId = 0
        self.Name = name  # used as topic on the queue
        if not empty:
            if privatekey is not None:
                self._connectLogger(privatekey, Host=ZeromqHost)
            if VisualizerHost is not None:
                self._connectVisualizer(VisualizerHost, root_id)
        self.g = DoDAG(root_id, root_id)
        self.root_id = root_id

    def _connectVisualizer(self, Host, root_id):
        """
        Connect to the Active Live Visualizer on TCP port 600 and send it the root id.

        On any failure the connection is closed and ``self.Active`` is reset to None,
        so the rest of the class silently skips the visualizer.

        :param Host: The ip of the FrankFancyGraphStreamer
        :param root_id: the ip6 address of the root node of the network
        :return:
        """
        try:
            logg.debug("Connecting Streaming Interface to Active Viewer")
            self.Active = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            self.Active.connect((Host, 600))
            logg.debug("Sending to Active Viewer:{}".format(root_id))
            self.Active.sendall(root_id)
        except Exception:
            # Best effort: the visualizer is optional. Do not leak the
            # half-opened socket, and do not swallow KeyboardInterrupt/SystemExit.
            logg.debug("Connection to Active Viewer failed!")
            if self.Active is not None:
                self.Active.close()
            self.Active = None

    def _connectLogger(self, key, Host="localhost"):
        """
        Open a zeromq PUB socket, acting as CURVE server, for the logger.

        :param key: privatekey file of the scheduler
        :param Host: which interface the zeromq service needs to bind to ("*" for all interfaces)
        :return:
        """
        # NOTE(review): ``key`` and ``Host`` are currently ignored; the
        # certificate path and bind address below are hard-coded. Confirm with
        # the callers before wiring them in: honouring Host (default "*" from
        # __init__) would expose the socket on all interfaces.
        # TODO: error handling on certificates missing and stuff
        # TODO: expose more security options such as white/blacklisting ips and domain filtering
        self.context = zmq.Context()
        self.auth = ThreadAuthenticator(self.context)
        self.auth.start()
        # Accept only clients whose public keys are stored in keys/public
        self.auth.configure_curve(domain='*', location=os.path.join("keys", "public"))
        self.Logger = self.context.socket(zmq.PUB)
        scheduler_public, scheduler_secret = zmq.auth.load_certificate(os.path.join("keys", "plexi1.key_secret"))
        self.Logger.curve_secretkey = scheduler_secret
        self.Logger.curve_publickey = scheduler_public
        self.Logger.curve_server = True
        self.Logger.bind("tcp://127.0.0.1:6000")

    def SendActiveJson(self, data):
        """
        Sends an object as json encoded to the Active Live Viewer.

        :param data: the object to be sent
        :return:
        """
        if self.Active is not None:
            logg.debug("Sending json data to Active: " + json.dumps(data))
            self.Active.sendall(json.dumps(data))

    def PublishLogging(self, LoggingName="zmq.auth", root_topic="zmq.auth"):
        """
        Publishes the given python logger to the publishing service.

        :param LoggingName: Name of the python logger service
        :type LoggingName: str
        :param root_topic: the topic given with message. is appended with .<LEVEL>
        :type root_topic: str
        :return:
        """
        handler = PUBHandler(self.Logger)
        handler.root_topic = root_topic
        handler.formatters[logging.DEBUG] = logging.Formatter(fmt='%(asctime)s\t%(levelname)s: %(message)s', datefmt='%H:%M:%S')
        handler.formatters[logging.INFO] = logging.Formatter(fmt='%(asctime)s\t%(levelname)s: %(message)s', datefmt='%H:%M:%S')
        l = logging.getLogger(LoggingName)
        l.addHandler(handler)

    def ChangeCell(self, who, slotoffs, channeloffs, frame, ID, status):
        """
        Notifies all active services about the changes to a cell in the schedule matrix.

        :param who: The node in which the cell is changed
        :type who: :class: `node.NodeID`
        :param slotoffs: slot offset
        :param channeloffs: channel offset
        :param frame: frame name
        :param ID: local cell id
        :param status: new status of the cell (key of ConvertStatus["Cells"])
        :return:
        """
        if self.Active is not None:
            logg.debug("Sending ChangeCell to active viewer")
            self.Active.sendall(json.dumps(["changecell", {"who": str(who), "channeloffs": channeloffs, "slotoffs": slotoffs, "frame": frame, "id": ID, "status": status}]))
        if self.Logger is not None:
            self.EventId += 1
            logg.debug("Sending ChangeCell to logger, EventID:" + str(self.EventId))
            self.Logger.send_multipart([self.Name.encode(), pickle.dumps(Event(self.EventId, self.ConvertStatus["Cells"][status], time.time(), json.dumps({"node_id": str(who), "channeloffs": channeloffs, "slotoffs": slotoffs, "frame": frame, "id": ID})))])

    def DumpDotData(self, labels=None):
        """
        Dumps the whole graph, rendered by self.g.draw_graph(), to the active viewer.

        This is not used for the logger.

        :param labels: extra labels forwarded to draw_graph (defaults to an empty dict)
        :return:
        """
        # A fresh dict per call: a shared mutable default could leak state between calls.
        if labels is None:
            labels = {}
        if self.Active is not None:
            logg.debug("Sending dotdata")
            dotdata = self.g.draw_graph(labels=labels)
            self.Active.sendall(bytearray(json.dumps(["\"" + self.root_id + " at " + time.strftime("%Y-%m-%d %H:%M:%S") + "\"", dotdata])))
            time.sleep(.5)

    def AddNode(self, node_id, parent):
        """
        Sends a notification of joining node to the logger and the visualizer.

        :param node_id: ip6 of the node
        :type node_id: str
        :param parent: ip6 of the parent node, or "root"
        :type parent: str
        :return:
        """
        node_id = str(node_id)
        parent = str(parent)
        if self.Logger is not None:
            self.EventId += 1
            logg.debug("Sending Addnode to logger, EventID:" + str(self.EventId))
            self.Logger.send_multipart([self.Name.encode(), pickle.dumps(Event(self.EventId, 0, time.time(), json.dumps({"node_id": str(node_id), "parent": str(parent)})))])
        if self.Active is not None:
            logg.debug("Sending Addnode to Active Visualizer, node:{}, parent:{}".format(node_id, parent))
            if parent == "root":
                self.g.attach_node(node_id)
            else:
                self.g.attach_child(node_id, parent)
            self.DumpDotData()

    def RewireNode(self, node_id, old_parent, new_parent):
        """
        Notifies the logger and the visualizer of a rewire that happened in the network.

        :param node_id: ip6 of the node that has rewired
        :param old_parent: ip6 of the old parent
        :param new_parent: ip6 of the new parent
        :return:
        """
        node_id = str(node_id)
        old_parent = str(old_parent)
        new_parent = str(new_parent)
        if self.Logger is not None:
            self.EventId += 1
            logg.debug("Sending RewireNode to logger, EventID: " + str(self.EventId))
            self.Logger.send_multipart([self.Name.encode(), pickle.dumps(Event(self.EventId, 2, time.time(), json.dumps({"node_id": str(node_id), "old_parent": str(old_parent), "new_parent": str(new_parent)})))])
        if self.Active is not None:
            logg.debug("Sending Rewire to the Active")
            self.g.attach_child(node_id, new_parent)
            self.DumpDotData()

    def RemoveNode(self, node_id):
        """
        Notifies the logger and the visualizer of a disconnected node.

        :param node_id: ip6 of the node that has disconnected
        :return:
        """
        node_id = str(node_id)
        if self.Logger is not None:
            self.EventId += 1
            logg.debug("Sending RemoveNode to logger, EventID: " + str(self.EventId))
            self.Logger.send_multipart([self.Name.encode(), pickle.dumps(Event(self.EventId, 1, time.time(), json.dumps({"node_id": str(node_id)})))])
        if self.Active is not None:
            self.g.detach_node(node_id)
            self.DumpDotData()

    def RegisterFrame(self, num_cells, framename):
        """
        Notifies the logger of a new frame that is defined in the scheduler algorithm.

        :param num_cells: number of cells per channel
        :param framename: unique identifying name
        :return:
        """
        if self.Logger is not None:
            self.EventId += 1
            logg.debug("Sending RegisterFrame to logger, EventID: " + str(self.EventId))
            self.Logger.send_multipart([self.Name.encode(), pickle.dumps(Event(self.EventId, 7, time.time(), json.dumps({"cells": num_cells, "name": framename})))])

    def RegisterFrames(self, frames):
        """
        Sends the JSON-encoded frame definitions to the Active Live Viewer.

        :param frames: JSON-serialisable frame description
        :return:
        """
        if self.Active is not None:
            logg.debug("Sending RegisterFrames to Active")
            self.Active.sendall(bytearray(json.dumps(frames)))
class ZmqListener:
    """
    REP endpoint answering requests from the acquisition server.

    Connects to ``zmq.acq_host`` with CURVE encryption and serves PING and
    READ requests (protobuf ``RequestDataMessage``), reading cycle data from
    Redis for READ.
    """

    def __init__(self, settings):
        self.redis = RedisScraper(settings)
        self.id = settings.getKey("box_id")
        self.log = logging.getLogger('ZMQ')
        self.clientPath = settings.getKey("zmq.private_cert")
        self.serverPath = settings.getKey("zmq.server_cert")
        if not self.clientPath or not self.serverPath:
            self.log.fatal(
                "zmq certificates not configured in the settings file")
            os._exit(1)
        self.host = settings.getKey("zmq.acq_host")
        self.ctx = zmq.Context()
        self.auth = ThreadAuthenticator(self.ctx)
        self.auth.start()
        #self.auth.allow('127.0.0.1')
        self.auth.configure_curve(domain='*', location=zmq.auth.CURVE_ALLOW_ANY)
        self.client = self.ctx.socket(zmq.REP)
        try:
            client_public, client_secret = zmq.auth.load_certificate(
                self.clientPath)
            self.client.curve_secretkey = client_secret
            self.client.curve_publickey = client_public
            server_public, _ = zmq.auth.load_certificate(self.serverPath)
            self.client.curve_serverkey = server_public
            self.client.connect(self.host)
        except (IOError, ValueError):
            # Missing/unreadable file (IOError) or malformed certificate (ValueError)
            self.log.fatal("Could not load client certificate")
            os._exit(1)
        self.log.info("ZMQ connected to " + self.host + " using certs " + self.clientPath)
        self.running = False
        # Dispatch table: request type -> handler(message) -> bool
        self.handlers = {
            opq_pb2.RequestDataMessage.PING: self.ping,
            opq_pb2.RequestDataMessage.READ: self.read,
        }

    def ping(self, message):
        """Answer a PING request with a PONG carrying the same message."""
        message.type = opq_pb2.RequestDataMessage.PONG
        message_buff = message.SerializeToString()
        self.log.info("Received a PING from server")
        self.client.send(message_buff)
        return True

    def read(self, message):
        """
        Answer a READ request with the cycles between
        ``message.time - message.back`` and ``message.time + message.front``.

        Replies with an ERROR message (and returns False) when front or back is 0.
        Returns True once the data has been sent.
        """
        self.log.debug("Received a data transfer request from server")
        try:
            if message.front == 0 or message.back == 0:
                message.type = opq_pb2.RequestDataMessage.ERROR
                message_buff = message.SerializeToString()
                self.log.info("Bad message from server")
                self.client.send(message_buff)
                return False
            cycles = self.redis.getRange(message.time - message.back,
                                         message.time + message.front)
            cycles.id = self.id
            cycles.mid = message.mid
            message_buff = cycles.SerializeToString()
            self.client.send(message_buff)
        except google.protobuf.message.DecodeError:
            self.log.fatal("Bad request from acquisition server.")
            return False
        return True

    def run(self):
        """
        Serve requests until ``self.running`` is cleared.

        A request that cannot be decoded stops the loop.
        """
        self.running = True
        try:
            while self.running:
                message_buff = self.client.recv()
                message = opq_pb2.RequestDataMessage()
                message.ParseFromString(message_buff)
                handler = self.handlers.get(message.type)
                if handler is None:
                    # A REP socket must reply before the next recv(); answer
                    # unsupported request types with ERROR instead of dying
                    # on a KeyError.
                    self.log.error("Unsupported request type %s from server", message.type)
                    message.type = opq_pb2.RequestDataMessage.ERROR
                    self.client.send(message.SerializeToString())
                    continue
                handler(message)
        except google.protobuf.message.DecodeError:
            self.log.fatal("Bad request from acquisition server.")
def test_encryption(tmpdir):
    """
    End-to-end check of reactobus CURVE encryption.

    Spawns reactobus with an encrypted ZMQPull input, an encrypted ZMQSub
    input and a plain ZMQPush output, then checks that a message sent on
    each encrypted input comes out unchanged on the output.
    """
    # Create the tmp names
    conf_filename = str(tmpdir.join("conf.yaml"))
    pull_url = tmpdir.join("input.pull.socket")
    pull_cert_dir = tmpdir.mkdir("input.pull")
    pull_clients_cert_dir = pull_cert_dir.mkdir("clients")
    sub_url = tmpdir.join("input.sub.socket")
    sub_cert_dir = tmpdir.mkdir("input.sub")
    push_url = tmpdir.join("output.push.socket")
    inbound = tmpdir.join("inbound")
    outbound = tmpdir.join("outbound")
    stdout = tmpdir.join("stdout")
    stderr = tmpdir.join("stderr")

    # Create the certificates
    create_certificates(str(pull_cert_dir), "pull")
    create_certificates(str(pull_clients_cert_dir), "client1")
    create_certificates(str(pull_clients_cert_dir), "client2")
    create_certificates(str(sub_cert_dir), "sub")
    create_certificates(str(sub_cert_dir), "sub-server")

    # Nested YAML mappings need consistent indentation under each list item.
    with open(conf_filename, "w") as f:
        f.write("inputs:\n")
        f.write("- class: ZMQPull\n")
        f.write("  name: in-pull\n")
        f.write("  options:\n")
        f.write("    url: ipc://%s\n" % pull_url)
        f.write("    encryption:\n")
        f.write("      self: %s\n" % pull_cert_dir.join("pull.key_secret"))
        f.write("      clients: %s\n" % pull_clients_cert_dir)
        f.write("- class: ZMQSub\n")
        f.write("  name: in-sub\n")
        f.write("  options:\n")
        f.write("    url: ipc://%s\n" % sub_url)
        f.write("    encryption:\n")
        f.write("      self: %s\n" % sub_cert_dir.join("sub.key_secret"))
        f.write("      server: %s\n" % sub_cert_dir.join("sub-server.key"))
        f.write("core:\n")
        f.write("  inbound: ipc://%s\n" % inbound)
        f.write("  outbound: ipc://%s\n" % outbound)
        f.write("outputs:\n")
        f.write("- class: ZMQPush\n")
        f.write("  name: out-push\n")
        f.write("  options:\n")
        f.write("    url: ipc://%s\n" % push_url)

    args = [
        "python3", "-m", "reactobus", "--conf", conf_filename, "--level",
        "DEBUG", "--log-file", "-",
    ]
    # Popen duplicates the descriptors into the child, so the parent's file
    # objects can be closed as soon as the process is spawned.
    with open(str(stdout), "w") as out, open(str(stderr), "w") as err:
        proc = subprocess.Popen(args, stdout=out, stderr=err)

    def build_message():
        """Return a fresh multipart message (new uuid and timestamp)."""
        return [
            b"org.videolan.git",
            b(str(uuid.uuid1())),
            b(datetime.datetime.utcnow().isoformat()),
            b("videolan-git"),
            b(
                json.dumps({
                    "url": "https://code.videolan.org/éêï",
                    "username": "******"
                })),
        ]

    socks = []
    auth = None
    try:
        ctx = zmq.Context.instance()

        # Create the input sockets
        in_sock = ctx.socket(zmq.PUSH)
        socks.append(in_sock)
        (server_public, _) = load_certificate(str(pull_cert_dir.join("pull.key")))
        in_sock.curve_serverkey = server_public
        (client_public, client_private) = load_certificate(
            str(pull_clients_cert_dir.join("client1.key_secret")))
        in_sock.curve_publickey = client_public
        in_sock.curve_secretkey = client_private
        in_sock.connect("ipc://%s" % pull_url)

        out_sock = ctx.socket(zmq.PULL)
        socks.append(out_sock)
        # Fail instead of hanging forever if a message never arrives
        out_sock.setsockopt(zmq.RCVTIMEO, 10000)
        out_sock.bind("ipc://%s" % push_url)

        pub_sock = ctx.socket(zmq.PUB)
        socks.append(pub_sock)
        auth = ThreadAuthenticator(ctx)
        auth.start()
        auth.configure_curve(domain="*", location=str(sub_cert_dir))
        (server_public, server_secret) = load_certificate(
            str(sub_cert_dir.join("sub-server.key_secret")))
        pub_sock.curve_publickey = server_public
        pub_sock.curve_secretkey = server_secret
        pub_sock.curve_server = True
        pub_sock.bind("ipc://%s" % sub_url)

        # Allow the process sometime to setup and connect
        time.sleep(1)

        # Send some data through the encrypted PULL input
        data = build_message()
        in_sock.send_multipart(data)
        msg = out_sock.recv_multipart()
        assert msg == data

        # ... and through the encrypted SUB input
        data = build_message()
        pub_sock.send_multipart(data)
        msg = out_sock.recv_multipart()
        assert msg == data
    finally:
        # End the process and release resources even when an assertion failed
        proc.terminate()
        proc.wait()
        for sock in socks:
            sock.close(linger=0)
        if auth is not None:
            auth.stop()
class RpcClient:
    """
    Client side of the RPC channel.

    Attribute access on an instance returns a proxy that performs a remote
    call over the REQ socket. Data published by the server is received on
    the SUB socket by a worker thread and handed to callback().
    """

    def __init__(self):
        """Constructor"""
        # zmq port related
        self.__context: zmq.Context = zmq.Context()

        # Request socket (Request–reply pattern)
        self.__socket_req: zmq.Socket = self.__context.socket(zmq.REQ)

        # Subscribe socket (Publish–subscribe pattern)
        self.__socket_sub: zmq.Socket = self.__context.socket(zmq.SUB)

        # Worker thread related, used to process data pushed from server
        self.__active: bool = False  # RpcClient status
        self.__thread: threading.Thread = None  # RpcClient thread
        self.__lock: threading.Lock = threading.Lock()

        # Authenticator used to ensure data security
        self.__authenticator: ThreadAuthenticator = None

        self._last_received_ping: datetime = datetime.utcnow()

    def __getattr__(self, name: str):
        """
        Return a callable that invokes the remote function ``name``.

        Only reached for attributes not found by normal lookup. The proxy is
        stored in the instance ``__dict__`` so later accesses bypass this
        method; unlike ``functools.lru_cache`` on a method, this does not keep
        every instance alive in a class-level cache.
        """
        def dorpc(*args, **kwargs):
            # Optional per-call timeout in milliseconds (default 30 s);
            # it is consumed here and not forwarded to the server.
            timeout = kwargs.pop("timeout", 30000)

            req = [name, args, kwargs]

            # Serialize the whole send/poll/recv exchange on the shared REQ socket.
            with self.__lock:
                self.__socket_req.send_pyobj(req)

                # Timeout reached without any data.
                # NOTE(review): the REQ socket still awaits this reply afterwards,
                # so a following send may fail — TODO confirm/handle.
                if not self.__socket_req.poll(timeout):
                    msg = f"Timeout of {timeout}ms reached for {req}"
                    raise RemoteException(msg)

                rep = self.__socket_req.recv_pyobj()

            # rep is [success_flag, result_or_traceback]
            if rep[0]:
                return rep[1]
            raise RemoteException(rep[1])

        self.__dict__[name] = dorpc
        return dorpc

    def start(self,
              req_address: str,
              sub_address: str,
              client_secretkey_path: str = "",
              server_publickey_path: str = "",
              username: str = "",
              password: str = "") -> None:
        """
        Start RpcClient: configure optional CURVE or PLAIN security, connect
        both sockets and launch the worker thread. No-op if already active.
        """
        if self.__active:
            return

        # Start authenticator
        if client_secretkey_path and server_publickey_path:
            self.__authenticator = ThreadAuthenticator(self.__context)
            self.__authenticator.start()
            self.__authenticator.configure_curve(
                domain="*",
                location=zmq.auth.CURVE_ALLOW_ANY)

            publickey, secretkey = zmq.auth.load_certificate(
                client_secretkey_path)
            serverkey, _ = zmq.auth.load_certificate(server_publickey_path)

            self.__socket_sub.curve_secretkey = secretkey
            self.__socket_sub.curve_publickey = publickey
            self.__socket_sub.curve_serverkey = serverkey

            self.__socket_req.curve_secretkey = secretkey
            self.__socket_req.curve_publickey = publickey
            self.__socket_req.curve_serverkey = serverkey
        elif username and password:
            self.__authenticator = ThreadAuthenticator(self.__context)
            self.__authenticator.start()
            self.__authenticator.configure_plain(
                domain="*",
                passwords={username: password})

            self.__socket_sub.plain_username = username.encode()
            self.__socket_sub.plain_password = password.encode()

            self.__socket_req.plain_username = username.encode()
            self.__socket_req.plain_password = password.encode()

        # Connect zmq port
        self.__socket_req.connect(req_address)
        self.__socket_sub.connect(sub_address)

        # Start RpcClient status
        self.__active = True

        # Start RpcClient thread
        self.__thread = threading.Thread(target=self.run)
        self.__thread.start()

        self._last_received_ping = datetime.utcnow()

    def stop(self) -> None:
        """
        Stop RpcClient: the worker thread exits at its next loop check.
        """
        if not self.__active:
            return

        # Stop RpcClient status
        self.__active = False

    def join(self) -> None:
        """Wait for RpcClient thread to exit."""
        if self.__thread and self.__thread.is_alive():
            self.__thread.join()
        self.__thread = None

    def run(self) -> None:
        """
        Worker loop: receive [topic, data] from the SUB socket and dispatch.

        Keep-alive messages only update _last_received_ping; silence longer
        than KEEP_ALIVE_TOLERANCE triggers on_disconnected(). Both sockets are
        closed when the loop ends.
        """
        pull_tolerance = int(KEEP_ALIVE_TOLERANCE.total_seconds() * 1000)

        while self.__active:
            if not self.__socket_sub.poll(pull_tolerance):
                self.on_disconnected()
                continue

            # Receive data from subscribe socket.
            # NOTE: recv_pyobj unpickles network data; only use with trusted servers.
            topic, data = self.__socket_sub.recv_pyobj(flags=NOBLOCK)

            if topic == KEEP_ALIVE_TOPIC:
                self._last_received_ping = data
            else:
                # Process data by callable function
                self.callback(topic, data)

        # Close socket
        self.__socket_req.close()
        self.__socket_sub.close()

    def callback(self, topic: str, data: Any) -> None:
        """
        Callable function, to be overridden by subclasses.
        """
        raise NotImplementedError

    def subscribe_topic(self, topic: str) -> None:
        """
        Subscribe data
        """
        self.__socket_sub.setsockopt_string(zmq.SUBSCRIBE, topic)

    def on_disconnected(self):
        """
        Callback when heartbeat is lost.
        """
        print(
            "RpcServer has no response over {tolerance} seconds, please check your connection."
            .format(tolerance=KEEP_ALIVE_TOLERANCE.total_seconds()))
class Command(LAVADaemonCommand):
    help = "LAVA log recorder"
    logger = None
    default_logfile = "/var/log/lava-server/lava-logs.log"

    def __init__(self, *args, **options):
        super().__init__(*args, **options)
        self.logger = logging.getLogger("lava-logs")
        self.log_socket = None
        self.auth = None
        self.controler = None
        self.inotify_fd = None
        self.pipe_r = None
        self.poller = None
        self.cert_dir_path = None
        # List of logs
        self.jobs = {}
        # Keep test cases in memory
        self.test_cases = []
        # Master status
        self.last_ping = 0
        self.ping_interval = TIMEOUT

    def add_arguments(self, parser):
        super().add_arguments(parser)

        net = parser.add_argument_group("network")
        net.add_argument('--socket',
                         default='tcp://*:5555',
                         help="Socket waiting for logs. Default: tcp://*:5555")
        net.add_argument('--master-socket',
                         default='tcp://localhost:5556',
                         help="Socket for master-slave communication. Default: tcp://localhost:5556")
        net.add_argument('--ipv6', default=False, action='store_true',
                         help="Enable IPv6 on the listening sockets")
        net.add_argument('--encrypt', default=False, action='store_true',
                         help="Encrypt messages")
        net.add_argument('--master-cert',
                         default='/etc/lava-dispatcher/certificates.d/master.key_secret',
                         help="Certificate for the master socket")
        net.add_argument('--slaves-certs',
                         default='/etc/lava-dispatcher/certificates.d',
                         help="Directory for slaves certificates")

    def handle(self, *args, **options):
        """
        Entry point: set up the sockets, run the main loop, then drain the
        logging queue before exiting.
        """
        # Initialize logging.
        self.setup_logging("lava-logs", options["level"], options["log_file"], FORMAT)

        self.logger.info("[INIT] Dropping privileges")
        if not self.drop_privileges(options['user'], options['group']):
            self.logger.error("[INIT] Unable to drop privileges")
            return

        filename = os.path.join(settings.MEDIA_ROOT, 'lava-logs-config.yaml')
        self.logger.debug("[INIT] Dumping config to %s", filename)
        with open(filename, 'w') as output:
            yaml.dump(options, output)

        # Create the sockets
        context = zmq.Context()
        self.log_socket = context.socket(zmq.PULL)
        self.controler = context.socket(zmq.ROUTER)
        self.controler.setsockopt(zmq.IDENTITY, b"lava-logs")
        # Limit the number of messages in the queue
        self.controler.setsockopt(zmq.SNDHWM, 2)
        # From http://api.zeromq.org/4-2:zmq-setsockopt#toc5
        # "Immediately readies that connection for data transfer with the master"
        self.controler.setsockopt(zmq.CONNECT_RID, b"master")

        if options['ipv6']:
            self.logger.info("[INIT] Enabling IPv6")
            self.log_socket.setsockopt(zmq.IPV6, 1)
            self.controler.setsockopt(zmq.IPV6, 1)

        if options['encrypt']:
            self.logger.info("[INIT] Starting encryption")
            try:
                self.auth = ThreadAuthenticator(context)
                self.auth.start()
                self.logger.debug("[INIT] Opening master certificate: %s", options['master_cert'])
                master_public, master_secret = zmq.auth.load_certificate(options['master_cert'])
                self.logger.debug("[INIT] Using slaves certificates from: %s", options['slaves_certs'])
                self.auth.configure_curve(domain='*', location=options['slaves_certs'])
            except OSError as err:
                self.logger.error("[INIT] %s", err)
                if self.auth is not None:
                    self.auth.stop()
                return
            self.log_socket.curve_publickey = master_public
            self.log_socket.curve_secretkey = master_secret
            self.log_socket.curve_server = True
            self.controler.curve_publickey = master_public
            self.controler.curve_secretkey = master_secret
            self.controler.curve_serverkey = master_public

            self.logger.debug("[INIT] Watching %s", options["slaves_certs"])
            self.cert_dir_path = options["slaves_certs"]
            self.inotify_fd = watch_directory(options["slaves_certs"])
            if self.inotify_fd is None:
                self.logger.error("[INIT] Unable to start inotify")

        self.log_socket.bind(options['socket'])
        self.controler.connect(options['master_socket'])

        # Poll on the sockets. This allow to have a
        # nice timeout along with polling.
        self.poller = zmq.Poller()
        self.poller.register(self.log_socket, zmq.POLLIN)
        self.poller.register(self.controler, zmq.POLLIN)
        if self.inotify_fd is not None:
            # Register the raw fd (like pipe_r below): poll() reports events
            # keyed by the registered object, and wait_for_messages() looks
            # them up by self.inotify_fd. Registering an os.fdopen() wrapper
            # made that lookup never match.
            self.poller.register(self.inotify_fd, zmq.POLLIN)

        # Translate signals into zmq messages
        (self.pipe_r, _) = self.setup_zmq_signal_handler()
        self.poller.register(self.pipe_r, zmq.POLLIN)

        self.logger.info("[INIT] listening for logs")
        # PING right now: the master is waiting for this message to start
        # scheduling.
        self.controler.send_multipart([b"master", b"PING"])

        try:
            self.main_loop()
        except BaseException as exc:
            self.logger.error("[EXIT] Unknown exception raised, leaving!")
            self.logger.exception(exc)

        # Close the controler socket
        self.controler.close(linger=0)
        self.poller.unregister(self.controler)

        # Carefully close the logging socket as we don't want to lose messages
        self.logger.info("[EXIT] Disconnect logging socket and process messages")
        endpoint = u(self.log_socket.getsockopt(zmq.LAST_ENDPOINT))
        self.logger.debug("[EXIT] unbinding from '%s'", endpoint)
        self.log_socket.unbind(endpoint)

        # Empty the queue
        try:
            while self.wait_for_messages(True):
                # Flush test cases cache for every iteration because we might
                # get killed soon.
                self.flush_test_cases()
        except BaseException as exc:
            self.logger.error("[EXIT] Unknown exception raised, leaving!")
            self.logger.exception(exc)
        finally:
            # Last flush
            self.flush_test_cases()
            self.logger.info("[EXIT] Closing the logging socket: the queue is empty")
            self.log_socket.close()
            if options['encrypt']:
                self.auth.stop()
            context.term()

    def flush_test_cases(self):
        """
        Save the cached test cases in bulk.

        On a database error, fall back to saving them one by one and drop
        the ones that still fail. The cache is emptied in every case.
        """
        if not self.test_cases:
            return

        # Try to save into the database
        try:
            TestCase.objects.bulk_create(self.test_cases)
            self.logger.info("Saving %d test cases", len(self.test_cases))
            self.test_cases = []
        except DatabaseError as exc:
            self.logger.error("Unable to flush the test cases")
            self.logger.exception(exc)
            self.logger.warning("Saving test cases one by one and dropping the faulty ones")
            saved = 0
            for tc in self.test_cases:
                with contextlib.suppress(DatabaseError):
                    tc.save()
                    saved += 1
            self.logger.info("%d test cases saved, %d dropped", saved, len(self.test_cases) - saved)
            self.test_cases = []

    def main_loop(self):
        """
        Process messages until a signal is received, with periodic test case
        flushing, closing of idle log files and pings to the master.
        """
        last_gc = time.time()
        last_bulk_create = time.time()

        # Wait for messages
        # TODO: fix timeout computation
        while self.wait_for_messages(False):
            now = time.time()

            # Dump TestCase into the database
            if now - last_bulk_create > BULK_CREATE_TIMEOUT:
                last_bulk_create = now
                self.flush_test_cases()

            # Close old file handlers
            if now - last_gc > FD_TIMEOUT:
                last_gc = now
                # Iterate while removing keys is not compatible with iterator
                for job_id in list(self.jobs.keys()):  # pylint: disable=consider-iterating-dictionary
                    if now - self.jobs[job_id].last_usage > FD_TIMEOUT:
                        self.logger.info("[%s] closing log file", job_id)
                        self.jobs[job_id].close()
                        del self.jobs[job_id]

            # Ping the master
            if now - self.last_ping > self.ping_interval:
                self.logger.debug("PING => master")
                self.last_ping = now
                self.controler.send_multipart([b"master", b"PING"])

    def wait_for_messages(self, leaving):
        """
        Poll once (up to TIMEOUT seconds) and handle at most one event.

        :param leaving: True while draining the queue at exit
        :return: False to stop the calling loop, True to keep going
        """
        try:
            try:
                sockets = dict(self.poller.poll(TIMEOUT * 1000))
            except zmq.error.ZMQError as exc:
                self.logger.error("[POLL] zmq error: %s", str(exc))
                return True

            # Messages
            if sockets.get(self.log_socket) == zmq.POLLIN:
                self.logging_socket()
                return True

            # Signals
            elif sockets.get(self.pipe_r) == zmq.POLLIN:
                # remove the message from the queue
                os.read(self.pipe_r, 1)

                if not leaving:
                    self.logger.info("[POLL] received a signal, leaving")
                    return False
                else:
                    self.logger.warning("[POLL] signal already handled, please wait for the process to exit")
                    return True

            # Pong received
            elif sockets.get(self.controler) == zmq.POLLIN:
                self.controler_socket()
                return True

            # Inotify socket
            elif sockets.get(self.inotify_fd) == zmq.POLLIN:
                os.read(self.inotify_fd, 4096)
                self.logger.debug("[AUTH] Reloading certificates from %s", self.cert_dir_path)
                self.auth.configure_curve(domain='*', location=self.cert_dir_path)
                # A certificate reload must not stop the main loop
                return True

            # Nothing received
            else:
                return not leaving
        except (OperationalError, InterfaceError):
            self.logger.info("[RESET] database connection reset")
            connection.close()
        return True

    def logging_socket(self):
        """Handle one [job_id, yaml_log_line] message from the logging socket."""
        msg = self.log_socket.recv_multipart()
        try:
            (job_id, message) = (u(m) for m in msg)  # pylint: disable=unbalanced-tuple-unpacking
        except ValueError:
            # do not let a bad message stop the master.
            self.logger.error("[POLL] failed to parse log message, skipping: %s", msg)
            return

        try:
            # NOTE(review): yaml.CLoader is the full (unsafe) loader and the
            # data comes from the network; consider yaml.CSafeLoader.
            scanned = yaml.load(message, Loader=yaml.CLoader)
        except yaml.YAMLError:
            self.logger.error("[%s] data are not valid YAML, dropping", job_id)
            return

        # Look for "results" level
        try:
            message_lvl = scanned["lvl"]
            message_msg = scanned["msg"]
        except TypeError:
            self.logger.error("[%s] not a dictionary, dropping", job_id)
            return
        except KeyError:
            self.logger.error(
                "[%s] invalid log line, missing \"lvl\" or \"msg\" keys: %s",
                job_id, message)
            return

        # Find the handler (if available)
        if job_id not in self.jobs:
            # Query the database for the job
            try:
                job = TestJob.objects.get(id=job_id)
            except TestJob.DoesNotExist:
                self.logger.error("[%s] unknown job id", job_id)
                return

            self.logger.info("[%s] receiving logs from a new job", job_id)
            # Create the sub directories (if needed)
            mkdir(job.output_dir)
            self.jobs[job_id] = JobHandler(job)

        # For 'event', send an event and log as 'debug'
        if message_lvl == 'event':
            self.logger.debug("[%s] event: %s", job_id, message_msg)
            send_event(".event", "lavaserver", {"message": message_msg, "job": job_id})
            message_lvl = "debug"
        # For 'marker', save in the database and log as 'debug'
        elif message_lvl == 'marker':
            # TODO: save on the file system in case of lava-logs restart
            # A non-dict payload would raise AttributeError and stop the main loop.
            if not isinstance(message_msg, dict):
                self.logger.error("[%s] invalid marker: %s", job_id, message_msg)
                return
            m_type = message_msg.get("type")
            case = message_msg.get("case")
            if m_type is None or case is None:
                self.logger.error("[%s] invalid marker: %s", job_id, message_msg)
                return
            self.jobs[job_id].markers.setdefault(case, {})[m_type] = self.jobs[job_id].line_count()
            # This is in fact the previous line
            self.jobs[job_id].markers[case][m_type] -= 1
            self.logger.debug("[%s] marker: %s line: %s", job_id, message_msg, self.jobs[job_id].markers[case][m_type])
            return

        # Mark the file handler as used
        self.jobs[job_id].last_usage = time.time()
        # The format is a list of dictionaries
        self.jobs[job_id].write("- %s" % message)

        if message_lvl == "results":
            try:
                job = TestJob.objects.get(pk=job_id)
            except TestJob.DoesNotExist:
                self.logger.error("[%s] unknown job id", job_id)
                return
            meta_filename = create_metadata_store(message_msg, job)
            new_test_case = map_scanned_results(results=message_msg, job=job,
                                                markers=self.jobs[job_id].markers,
                                                meta_filename=meta_filename)

            if new_test_case is None:
                self.logger.warning(
                    "[%s] unable to map scanned results: %s",
                    job_id, message)
            else:
                self.test_cases.append(new_test_case)

            # Look for lava.job result
            if message_msg.get("definition") == "lava" and message_msg.get("case") == "job":
                # Flush cached test cases
                self.flush_test_cases()

                if message_msg.get("result") == "pass":
                    health = TestJob.HEALTH_COMPLETE
                    health_msg = "Complete"
                else:
                    health = TestJob.HEALTH_INCOMPLETE
                    health_msg = "Incomplete"
                self.logger.info("[%s] job status: %s", job_id, health_msg)

                infrastructure_error = (message_msg.get("error_type") in ["Bug", "Configuration", "Infrastructure"])
                if infrastructure_error:
                    self.logger.info("[%s] Infrastructure error", job_id)

                # Update status.
                with transaction.atomic():
                    # TODO: find a way to lock actual_device
                    job = TestJob.objects.select_for_update() \
                                         .get(id=job_id)
                    job.go_state_finished(health, infrastructure_error)
                    job.save()

        # n.b. logging here would produce a log entry for every message in every job.

    def controler_socket(self):
        """Handle a [master, PONG, ping_interval] answer from the master."""
        msg = self.controler.recv_multipart()
        try:
            master_id = u(msg[0])
            action = u(msg[1])
            ping_interval = int(msg[2])

            if master_id != "master":
                self.logger.error("Invalid master id '%s'. Should be 'master'", master_id)
                return
            if action != "PONG":
                self.logger.error("Invalid answer '%s'. Should be 'PONG'", action)
                return
        except (IndexError, ValueError):
            self.logger.error("Invalid message '%s'", msg)
            return

        if ping_interval < TIMEOUT:
            self.logger.error("invalid ping interval (%d) too small", ping_interval)
            return

        self.logger.debug("master => PONG(%d)", ping_interval)
        self.ping_interval = ping_interval
class RpcServer:
    """
    ZeroMQ request/reply RPC server with a publish channel.

    Clients send ``(name, args, kwargs)`` on the REP socket; the server calls the
    function registered under ``name`` and replies ``[True, result]`` on success
    or ``[False, formatted_traceback]`` on failure. The PUB socket carries data
    sent through :meth:`publish` plus periodic heartbeats.
    """

    def __init__(self) -> None:
        """
        Constructor
        """
        # Registered functions: key is function name, value is function object
        self._functions: Dict[str, Callable] = {}

        # Zmq port related
        self._context: zmq.Context = zmq.Context()

        # Reply socket (Request–reply pattern)
        self._socket_rep: zmq.Socket = self._context.socket(zmq.REP)

        # Publish socket (Publish–subscribe pattern)
        self._socket_pub: zmq.Socket = self._context.socket(zmq.PUB)

        # Worker thread related
        self._active: bool = False                      # RpcServer status
        self._thread: threading.Thread = None           # RpcServer thread
        self._lock: threading.Lock = threading.Lock()   # serializes sends on the PUB socket

        # Heartbeat related: timestamp (from time()) of the next heartbeat
        self._heartbeat_at: float = None

        # Authenticator used to ensure data security
        self.__authenticator: ThreadAuthenticator = None

    def is_active(self) -> bool:
        """Return True while the server is running."""
        return self._active

    def start(
        self,
        rep_address: str,
        pub_address: str,
        username: str = "",
        password: str = "",
        server_secretkey_path: str = ""
    ) -> None:
        """
        Start RpcServer: configure optional authentication, bind both sockets
        and launch the worker thread. Does nothing if already active.

        :param rep_address: endpoint the REP (request/reply) socket binds to
        :param pub_address: endpoint the PUB socket binds to
        :param username: PLAIN auth user name (used only with ``password`` and
            when no ``server_secretkey_path`` is given)
        :param password: PLAIN auth password
        :param server_secretkey_path: certificate file for CURVE encryption;
            takes precedence over username/password
        """
        if self._active:
            return

        # Start authenticator
        if server_secretkey_path:
            self.__authenticator = ThreadAuthenticator(self._context)
            self.__authenticator.start()
            # CURVE_ALLOW_ANY: traffic is encrypted, but any client key is accepted
            self.__authenticator.configure_curve(
                domain="*",
                location=zmq.auth.CURVE_ALLOW_ANY
            )

            publickey, secretkey = zmq.auth.load_certificate(
                server_secretkey_path)

            # CURVE options must be set before bind
            for socket in (self._socket_pub, self._socket_rep):
                socket.curve_secretkey = secretkey
                socket.curve_publickey = publickey
                socket.curve_server = True

        elif username and password:
            self.__authenticator = ThreadAuthenticator(self._context)
            self.__authenticator.start()
            self.__authenticator.configure_plain(
                domain="*",
                passwords={username: password}
            )

            for socket in (self._socket_pub, self._socket_rep):
                socket.plain_server = True

        # Bind socket address
        self._socket_rep.bind(rep_address)
        self._socket_pub.bind(pub_address)

        # Init heartbeat publish timestamp BEFORE the worker thread can read it
        # in check_heartbeat() (comparing a float with None raises TypeError).
        self._heartbeat_at = time() + HEARTBEAT_INTERVAL

        # Start RpcServer status
        self._active = True

        # Start RpcServer thread
        self._thread = threading.Thread(target=self.run)
        self._thread.start()

    def stop(self) -> None:
        """
        Stop RpcServer: ask the worker loop to exit (see :meth:`join`).
        """
        if not self._active:
            return

        # Stop RpcServer status
        self._active = False

    def join(self) -> None:
        """Wait for the RpcServer worker thread to exit."""
        if self._thread and self._thread.is_alive():
            self._thread.join()
        self._thread = None

    def run(self) -> None:
        """
        Worker-thread loop: serve requests until :meth:`stop` clears
        ``_active``, publishing heartbeats along the way; then unbind both
        sockets and stop the authenticator.
        """
        while self._active:
            # Poll the reply socket for up to 1 second so stop() is noticed
            n: int = self._socket_rep.poll(1000)
            self.check_heartbeat()

            if not n:
                continue

            # Receive, decode and execute the request; any failure (malformed
            # request, unknown name, exception in the callee) becomes an error reply.
            try:
                # NOTE(security): recv_pyobj unpickles whatever the peer sent.
                # Only bind this socket where every peer is trusted.
                name, args, kwargs = self._socket_rep.recv_pyobj()
                func: Callable = self._functions[name]
                r: Any = func(*args, **kwargs)
                rep: list = [True, r]
            except Exception:  # noqa
                rep = [False, traceback.format_exc()]

            # A REP socket must answer every request before it can receive
            # again, so a reply that cannot be serialized is replaced by an error.
            try:
                self._socket_rep.send_pyobj(rep)
            except Exception:  # noqa
                self._socket_rep.send_pyobj([False, traceback.format_exc()])

        # Unbind socket address
        self._socket_pub.unbind(self._socket_pub.LAST_ENDPOINT)
        self._socket_rep.unbind(self._socket_rep.LAST_ENDPOINT)

        if self.__authenticator:
            self.__authenticator.stop()
            self.__authenticator = None

    def publish(self, topic: str, data: Any) -> None:
        """
        Publish ``[topic, data]`` on the PUB socket (thread-safe via ``_lock``).
        """
        with self._lock:
            self._socket_pub.send_pyobj([topic, data])

    def register(self, func: Callable) -> None:
        """
        Register ``func`` so clients can call it by its ``__name__``.
        """
        self._functions[func.__name__] = func

    def check_heartbeat(self) -> None:
        """
        Publish a heartbeat if the scheduled time has been reached.
        """
        now: float = time()

        if now >= self._heartbeat_at:
            # Publish heartbeat
            self.publish(HEARTBEAT_TOPIC, now)

            # Update timestamp of next publish
            self._heartbeat_at = now + HEARTBEAT_INTERVAL
class Device(Actor):
    '''
    The actor class implements all the management and control functions over its components.

    A Device is an actor synthesized around a single device component: its
    actor model has the device's formals, one instance of the device, and all
    of the device's port message types as local messages.

    NOTE(review): parseParams, buildInstArgs and setupIfaces are called here
    but not defined in this class; presumably inherited from Actor — verify.
    '''

    def __init__(self, gModel, gModelName, dName, sysArgv):
        '''
        Constructor

        :param gModel: application model; this method reads its "name",
            "devices", "messages" keys
        :param gModelName: name of the model
        :param dName: device name; must be a key of gModel["devices"]
        :param sysArgv: argument list handed to parseParams
        :raises BuildError: if the device (or its type) is not in the model
        '''
        self.logger = logging.getLogger(__name__)
        self.inst_ = self
        self.appName = gModel["name"]
        self.modelName = gModelName
        self.name = dName
        self.pid = os.getpid()
        self.suffix = ""
        if dName not in gModel["devices"]:
            raise BuildError('Device "%s" unknown' % dName)

        # In order to make the rest of the code work, we build an actor model for the device
        devModel = gModel["devices"][dName]
        formals = devModel["formals"]  # Formals are the same as those of the device (component)
        # There is a single instance, containing the device component; each of
        # its actuals is bound to the actor parameter of the same name.
        devInst = {
            "type": dName,
            "actuals": [{"name": arg["name"], "param": arg["name"]} for arg in formals],
        }
        self.model = {  # The made-up actor model
            "formals": formals,
            "instances": {dName: devInst},
            "locals": self.getMessageTypes(devModel),  # All messages are local
            "internals": {},  # No internals
        }

        self.INT_RE = re.compile(r"^[-]?\d+$")
        self.parseParams(sysArgv)

        # Use czmq's context
        czmq_ctx = Zsys.init()
        self.context = zmq.Context.shadow(czmq_ctx.value)
        Zsys.handler_reset()  # Reset previous signal

        if Config.SECURITY:
            # Separate context for app sockets; the authenticator is attached to it only.
            # (Created only here so no unused context is leaked when security is off.)
            self.appContext = zmq.Context()
            (self.public_key, self.private_key) = zmq.auth.load_certificate(const.appCertFile)
            hosts = ['127.0.0.1'] + self._loadAppHosts()
            self.auth = ThreadAuthenticator(self.appContext)
            self.auth.start()
            self.auth.allow(*hosts)
            self.auth.configure_curve(domain='*', location=zmq.auth.CURVE_ALLOW_ANY)
        else:
            (self.public_key, self.private_key) = (None, None)
            self.auth = None
            self.appContext = self.context

        # Best-effort logger configuration: failures are logged, not raised
        try:
            if os.path.isfile(const.logConfFile) and os.access(const.logConfFile, os.R_OK):
                spdlog_setup.from_file(const.logConfFile)
        except Exception as e:
            self.logger.error("error while configuring componentLogger: %s", repr(e))

        # Global message types (global on the network)
        self.messageNames = [spec["name"] for spec in gModel["messages"]]
        # Local message types (local to the host)
        self.localNames = [spec["type"] for spec in self.model["locals"]]
        # Internal message types (internal to the actor process)
        self.internalNames = [spec["type"] for spec in self.model["internals"]]

        # Create the component instances: the 'parts'
        self.components = {}
        devSpecs = gModel["devices"]
        for instName, instSpec in self.model["instances"].items():
            instType = instSpec['type']
            if instType not in devSpecs:
                raise BuildError(
                    'Device type "%s" for instance "%s" is undefined' % (instType, instName))
            typeSpec = devSpecs[instType]
            instArgs = self.buildInstArgs(instName, typeSpec['formals'], instSpec['actuals'])

            # A C++ implementation is detected as lib<type>.so in the working directory
            libName = 'lib' + instType.lower()
            ccComp = os.path.isfile(libName + '.so')
            try:
                if ccComp:
                    modObj = importlib.import_module(libName)
                    self.components[instName] = modObj.create_component_py(
                        self, self.model, typeSpec, instName, instType, instArgs,
                        self.appName, self.name)
                else:
                    self.components[instName] = Part(self, typeSpec, instName, instType, instArgs)
            except Exception as e:
                # A part that fails to build is reported and left out of self.components
                traceback.print_exc()
                self.logger.error("Error while constructing part '%s.%s': %s" %
                                  (instType, instName, str(e)))

    def _loadAppHosts(self):
        '''
        Return the extra hosts to allow, read from the 'hosts' entry of the app
        description file; an empty list if the file is absent or unparsable.

        NOTE(review): assumes the file is a YAML mapping whose 'hosts' value is a
        host string or a list of hosts — confirm against the file's producer.
        '''
        try:
            with open(const.appDescFile, 'r') as f:
                content = yaml.safe_load(f)  # plain data only; never construct objects
        except (OSError, yaml.YAMLError) as e:
            self.logger.debug("app description %s not loaded: %r", const.appDescFile, e)
            return []
        hosts = content.get('hosts') if isinstance(content, dict) else None
        if not hosts:
            return []
        if isinstance(hosts, str):
            return [hosts]
        return list(hosts)

    def getPortMessageTypes(self, ports, key, kinds, res):
        '''
        Append a {"type": <name>} entry to res for each field in kinds of every
        port in the ports[key] group; a missing or empty group contributes nothing.
        '''
        for _name, spec in (ports.get(key) or {}).items():
            res.extend({"type": spec[kind]} for kind in kinds)

    def getMessageTypes(self, devModel):
        '''
        Return the list of {"type": <name>} entries for every message type used
        by the device's ports (publish/subscribe ports first, then the
        request/reply-style ports).
        '''
        ports = devModel["ports"]
        res = []
        for key in ("pubs", "subs"):
            self.getPortMessageTypes(ports, key, ["type"], res)
        for key in ("reqs", "reps", "clts", "srvs", "qrys", "anss"):
            self.getPortMessageTypes(ports, key, ["req_type", "rep_type"], res)
        return res

    def setup(self):
        '''
        Perform a setup operation on the actor (after the initial construction but before the activation of parts)
        '''
        self.logger.info("setup")
        self.setupIfaces()
        self.suffix = self.macAddress
        self.disco = DiscoClient(self, self.suffix)
        self.disco.start()  # Start the discovery service client
        self.disco.registerApp()  # Register this actor with the discovery service
        self.logger.info("device registered with disco")
        self.deplc = DeplClient(self, self.suffix)
        self.deplc.start()
        ok = self.deplc.registerApp(isDevice=True)
        self.logger.info("device %s registered with depl" % ("is" if ok else "is not"))

        # One inproc PAIR control socket per component
        self.controls = {}
        self.controlMap = {}
        for inst, comp in self.components.items():
            control = self.context.socket(zmq.PAIR)
            control.bind('inproc://part_' + inst + '_control')
            self.controls[inst] = control
            self.controlMap[id(control)] = comp
            if isinstance(comp, Part):
                comp.setup(control)
            else:
                comp.setup()

    def terminate(self):
        '''
        Terminate all components and the discovery client, then exit the process.
        '''
        self.logger.info("terminating")
        for component in self.components.values():
            component.terminate()
        self.disco.terminate()
        time.sleep(1.0)
        self.logger.info("terminated")
        os._exit(0)  # hard exit: no interpreter cleanup runs after this
class ZmqConnector:
    """
    CURVE-secured ZeroMQ endpoint of the online broker.

    Binds a PULL socket for incoming client messages and a PUB socket for
    outgoing messages. Both use the server certificate
    ``../certs/private/server.key_secret`` (relative to the working directory).
    Client public keys are accepted from ``../certs/public``.
    """
    context = None
    auth = None
    public_keys_dir = None
    secret_keys_dir = None
    puller = None
    publisher = None
    HOST = ''
    opponent_id = None
    available_player = 'XXX'

    # Client message protocol:
    #
    # 1. ID: Player ID (random string created by online broker)
    # 2. ACTION: status, command, recipient.
    # 3. MATCH: relevant player match data

    # Server message protocol:
    #
    # 1. Recipient: Player ID of recipient, used for filtering
    # 2. ACTION: sender (SERVER or opponent player's ID), command (welcome, wait, ready, play)
    # 3. Data: forwarded payload

    def __init__(self, host='127.0.0.1'):
        """Create the ZMQ context; ``host`` is the address allowed by the authenticator."""
        print("[zmq] Initializing ZMQ client object...")
        self.HOST = host
        self.context = zmq.Context()

    def setup(self):
        """Verify the certificate folders, start the authenticator and bind both sockets."""
        if not self.check_folder_structure():
            return None
        self.server_auth()
        self.bind_pull()
        self.bind_pub()

    def check_folder_structure(self):
        """
        Set the key directory paths and return True only if the certs folder
        and BOTH its 'public' and 'private' sub-folders exist.
        """
        keys_dir = os.path.join(os.getcwd(), '../certs')
        print(f"[#] checking folder structure: {keys_dir}")
        self.public_keys_dir = os.path.join(
            keys_dir, 'public')  # has the public keys of registered clients
        self.secret_keys_dir = os.path.join(
            keys_dir, 'private')  # has the server's private cert

        # Any single missing folder makes the setup unusable
        missing = [path for path in (keys_dir, self.public_keys_dir, self.secret_keys_dir)
                   if not os.path.exists(path)]
        if missing:
            print("[!!] Certificates folders are missing")
            for path in missing:
                print(f"\t{path}")
            return False
        return True

    def server_auth(self):
        """Start an authenticator accepting clients whose keys are in public_keys_dir."""
        print("[#] Starting authenticator...")
        self.auth = ThreadAuthenticator(self.context)
        self.auth.start()
        self.auth.allow(self.HOST)
        # give authenticator access to approved clients' certificate directory
        self.auth.configure_curve(domain='*', location=self.public_keys_dir)

    def _apply_server_certificate(self, socket):
        """Load the server key pair onto ``socket`` and make it a CURVE server (before bind)."""
        server_secret_file = os.path.join(self.secret_keys_dir, "server.key_secret")
        socket.curve_publickey, socket.curve_secretkey = zmq.auth.load_certificate(
            server_secret_file)
        socket.curve_server = True  # must come before bind

    def bind_pull(self, port=5555):
        """Create the CURVE-secured PULL socket and bind it on all interfaces."""
        print("[zmq] Binding PULL socket : {}".format(port))
        self.puller = self.context.socket(zmq.PULL)
        self._apply_server_certificate(self.puller)
        self.puller.bind("tcp://*:{}".format(port))

    def pull_receive_multi(self):
        """Block until a multipart message arrives; return it, or None on error."""
        try:
            message = self.puller.recv_multipart()
            print(f"[zmq] Received :\n\t{datetime.datetime.now()}- {message}")
            return message
        except zmq.Again:
            # Only possible with a receive timeout / non-blocking flags
            return None
        except zmq.ZMQError as e:
            print("[!zmq!] Error while getting messages: {}".format(e))
            print(traceback.format_exc())
            return None

    def bind_pub(self, port=5556):
        """Create the CURVE-secured PUB socket and bind it on all interfaces."""
        print("[zmq] Binding PUB socket: {}".format(port))
        self.publisher = self.context.socket(zmq.PUB)
        self._apply_server_certificate(self.publisher)
        self.publisher.bind("tcp://*:{}".format(port))

    def send(self, recipient, info, payload):
        """Publish [recipient, json(info), json(payload)] as a 3-frame message."""
        message = [
            recipient.encode(),
            json.dumps(info).encode(),
            json.dumps(payload).encode(),
        ]
        self.pub_send_multi(message)

    def pub_send_multi(self, message):
        """Send a multipart message on the PUB socket; errors are printed, not raised."""
        try:
            self.publisher.send_multipart(message)
            print(f"[zmq] Sent :\n\t{datetime.datetime.now()}- {message}")
        except TypeError as e:
            print("[!zmq!] TypeError while sending message: {}".format(e))
            print(traceback.format_exc())
        except ValueError as e:
            print("[!zmq!] ValueError while sending message: {}".format(e))
            print(traceback.format_exc())
        except zmq.ZMQError as e:
            print("[!zmq!] ZMQError while sending message: {}".format(e))
            print(traceback.format_exc())

    # GENERIC FUNCTIONS
    def disconnect(self):
        """Close both sockets, stop the authenticator thread and terminate the context."""
        print("[zmq] Disconnecting client...")
        for socket in (self.publisher, self.puller):
            if socket is not None:
                socket.close()
        # The authenticator owns sockets/threads on this context; stop it before term()
        if self.auth is not None:
            self.auth.stop()
            self.auth = None
        self.context.term()
class Command(LAVADaemonCommand): """ worker_host is the hostname of the worker this field is set by the admin and could therefore be empty in a misconfigured instance. """ logger = None help = "LAVA dispatcher master" default_logfile = "/var/log/lava-server/lava-master.log" def __init__(self, *args, **options): super(Command, self).__init__(*args, **options) self.auth = None self.controler = None self.event_socket = None self.poller = None self.pipe_r = None self.inotify_fd = None # List of logs # List of known dispatchers. At startup do not load this from the # database. This will help to know if the slave as restarted or not. self.dispatchers = {"lava-logs": SlaveDispatcher("lava-logs", online=False)} self.events = {"canceling": set()} def add_arguments(self, parser): super(Command, self).add_arguments(parser) # Important: ensure share/env.yaml is put into /etc/ by setup.py in packaging. config = parser.add_argument_group("dispatcher config") config.add_argument('--env', default="/etc/lava-server/env.yaml", help="Environment variables for the dispatcher processes. " "Default: /etc/lava-server/env.yaml") config.add_argument('--env-dut', default="/etc/lava-server/env.dut.yaml", help="Environment variables for device under test. " "Default: /etc/lava-server/env.dut.yaml") config.add_argument('--dispatchers-config', default="/etc/lava-server/dispatcher.d", help="Directory that might contain dispatcher specific configuration") net = parser.add_argument_group("network") net.add_argument('--master-socket', default='tcp://*:5556', help="Socket for master-slave communication. 
Default: tcp://*:5556") net.add_argument('--event-url', default="tcp://localhost:5500", help="URL of the publisher") net.add_argument('--ipv6', default=False, action='store_true', help="Enable IPv6 on the listening sockets") net.add_argument('--encrypt', default=False, action='store_true', help="Encrypt messages") net.add_argument('--master-cert', default='/etc/lava-dispatcher/certificates.d/master.key_secret', help="Certificate for the master socket") net.add_argument('--slaves-certs', default='/etc/lava-dispatcher/certificates.d', help="Directory for slaves certificates") def send_status(self, hostname): """ The master crashed, send a STATUS message to get the current state of jobs """ jobs = TestJob.objects.filter(actual_device__worker_host__hostname=hostname, state=TestJob.STATE_RUNNING) for job in jobs: self.logger.info("[%d] STATUS => %s (%s)", job.id, hostname, job.actual_device.hostname) send_multipart_u(self.controler, [hostname, 'STATUS', str(job.id)]) def dispatcher_alive(self, hostname): if hostname not in self.dispatchers: # The server crashed: send a STATUS message self.logger.warning("Unknown dispatcher <%s> (server crashed)", hostname) self.dispatchers[hostname] = SlaveDispatcher(hostname) self.send_status(hostname) # Mark the dispatcher as alive self.dispatchers[hostname].alive() def controler_socket(self): try: # We need here to use the zmq.NOBLOCK flag, otherwise we could block # the whole main loop where this function is called. 
msg = self.controler.recv_multipart(zmq.NOBLOCK) except zmq.error.Again: return False # This is way to verbose for production and should only be activated # by (and for) developers # self.logger.debug("[CC] Receiving: %s", msg) # 1: the hostname (see ZMQ documentation) hostname = u(msg[0]) # 2: the action action = u(msg[1]) # Check that lava-logs only send PINGs if hostname == "lava-logs" and action != "PING": self.logger.error("%s => %s Invalid action from log daemon", hostname, action) return True # Handle the actions if action == 'HELLO' or action == 'HELLO_RETRY': self._handle_hello(hostname, action, msg) elif action == 'PING': self._handle_ping(hostname, action, msg) elif action == 'END': self._handle_end(hostname, action, msg) elif action == 'START_OK': self._handle_start_ok(hostname, action, msg) else: self.logger.error("<%s> sent unknown action=%s, args=(%s)", hostname, action, msg[1:]) return True def read_event_socket(self): try: msg = self.event_socket.recv_multipart(zmq.NOBLOCK) except zmq.error.Again: return False try: (topic, _, dt, username, data) = (u(m) for m in msg) except ValueError: self.logger.error("Invalid event: %s", msg) return True if topic.endswith(".testjob"): try: data = simplejson.loads(data) if data["state"] == "Canceling": self.events["canceling"].add(int(data["job"])) except ValueError: self.logger.error("Invalid event data: %s", msg) return True def _handle_end(self, hostname, action, msg): # pylint: disable=unused-argument try: job_id = int(msg[2]) error_msg = msg[3] compressed_description = msg[4] except (IndexError, ValueError): self.logger.error("Invalid message from <%s> '%s'", hostname, msg) return try: job = TestJob.objects.get(id=job_id) except TestJob.DoesNotExist: self.logger.error("[%d] Unknown job", job_id) # ACK even if the job is unknown to let the dispatcher # forget about it send_multipart_u(self.controler, [hostname, 'END_OK', str(job_id)]) return filename = os.path.join(job.output_dir, 'description.yaml') # If 
description.yaml already exists: a END was already received if os.path.exists(filename): self.logger.info("[%d] %s => END (duplicated), skipping", job_id, hostname) else: if compressed_description: self.logger.info("[%d] %s => END", job_id, hostname) else: self.logger.info("[%d] %s => END (lava-run crashed, mark job as INCOMPLETE)", job_id, hostname) with transaction.atomic(): # TODO: find a way to lock actual_device job = TestJob.objects.select_for_update() \ .get(id=job_id) job.go_state_finished(TestJob.HEALTH_INCOMPLETE) if error_msg: self.logger.error("[%d] Error: %s", job_id, error_msg) job.failure_comment = error_msg job.save() # Create description.yaml even if it's empty # Allows to know when END messages are duplicated try: # Create the directory if it was not already created mkdir(os.path.dirname(filename)) # TODO: check that compressed_description is not "" description = lzma.decompress(compressed_description) with open(filename, 'w') as f_description: f_description.write(description.decode("utf-8")) if description: parse_job_description(job) except (IOError, lzma.LZMAError) as exc: self.logger.error("[%d] Unable to dump 'description.yaml'", job_id) self.logger.exception("[%d] %s", job_id, exc) # ACK the job and mark the dispatcher as alive send_multipart_u(self.controler, [hostname, 'END_OK', str(job_id)]) self.dispatcher_alive(hostname) def _handle_hello(self, hostname, action, msg): # Check the protocol version try: slave_version = int(msg[2]) except (IndexError, ValueError): self.logger.error("Invalid message from <%s> '%s'", hostname, msg) return self.logger.info("%s => %s", hostname, action) if slave_version != PROTOCOL_VERSION: self.logger.error("<%s> using protocol v%d while master is using v%d", hostname, slave_version, PROTOCOL_VERSION) return send_multipart_u(self.controler, [hostname, 'HELLO_OK']) # If the dispatcher is known and sent an HELLO, means that # the slave has restarted if hostname in self.dispatchers: if action == 'HELLO': 
self.logger.warning("Dispatcher <%s> has RESTARTED", hostname) else: # Assume the HELLO command was received, and the # action succeeded. self.logger.warning("Dispatcher <%s> was not confirmed", hostname) else: # No dispatcher, treat HELLO and HELLO_RETRY as a normal HELLO # message. self.logger.warning("New dispatcher <%s>", hostname) self.dispatchers[hostname] = SlaveDispatcher(hostname) # Mark the dispatcher as alive self.dispatcher_alive(hostname) def _handle_ping(self, hostname, action, msg): # pylint: disable=unused-argument self.logger.debug("%s => PING(%d)", hostname, PING_INTERVAL) # Send back a signal send_multipart_u(self.controler, [hostname, 'PONG', str(PING_INTERVAL)]) self.dispatcher_alive(hostname) def _handle_start_ok(self, hostname, action, msg): # pylint: disable=unused-argument try: job_id = int(msg[2]) except (IndexError, ValueError): self.logger.error("Invalid message from <%s> '%s'", hostname, msg) return self.logger.info("[%d] %s => START_OK", job_id, hostname) try: with transaction.atomic(): # TODO: find a way to lock actual_device job = TestJob.objects.select_for_update() \ .get(id=job_id) job.go_state_running() job.save() except TestJob.DoesNotExist: self.logger.error("[%d] Unknown job", job_id) else: self.dispatcher_alive(hostname) def export_definition(self, job): # pylint: disable=no-self-use job_def = yaml.load(job.definition) job_def['compatibility'] = job.pipeline_compatibility # no need for the dispatcher to retain comments return yaml.dump(job_def) def save_job_config(self, job, worker, device_cfg, options): output_dir = job.output_dir mkdir(output_dir) with open(os.path.join(output_dir, "job.yaml"), "w") as f_out: f_out.write(self.export_definition(job)) with contextlib.suppress(IOError): shutil.copy(options["env"], os.path.join(output_dir, "env.yaml")) with contextlib.suppress(IOError): shutil.copy(options["env_dut"], os.path.join(output_dir, "env.dut.yaml")) with contextlib.suppress(IOError): 
shutil.copy(os.path.join(options["dispatchers_config"], "%s.yaml" % worker.hostname), os.path.join(output_dir, "dispatcher.yaml")) with open(os.path.join(output_dir, "device.yaml"), "w") as f_out: yaml.dump(device_cfg, f_out) def start_job(self, job, options): # Load job definition to get the variables for template # rendering job_def = yaml.load(job.definition) job_ctx = job_def.get('context', {}) device = job.actual_device worker = device.worker_host # Load configurations env_str = load_optional_yaml_file(options['env']) env_dut_str = load_optional_yaml_file(options['env_dut']) device_cfg = device.load_configuration(job_ctx) dispatcher_cfg_file = os.path.join(options['dispatchers_config'], "%s.yaml" % worker.hostname) dispatcher_cfg = load_optional_yaml_file(dispatcher_cfg_file) self.save_job_config(job, worker, device_cfg, options) self.logger.info("[%d] START => %s (%s)", job.id, worker.hostname, device.hostname) send_multipart_u(self.controler, [worker.hostname, 'START', str(job.id), self.export_definition(job), yaml.dump(device_cfg), dispatcher_cfg, env_str, env_dut_str]) # For multinode jobs, start the dynamic connections parent = job for sub_job in job.sub_jobs_list: if sub_job == parent or not sub_job.dynamic_connection: continue # inherit only enough configuration for dynamic_connection operation self.logger.info("[%d] Trimming dynamic connection device configuration.", sub_job.id) min_device_cfg = parent.actual_device.minimise_configuration(device_cfg) self.save_job_config(sub_job, worker, min_device_cfg, options) self.logger.info("[%d] START => %s (connection)", sub_job.id, worker.hostname) send_multipart_u(self.controler, [worker.hostname, 'START', str(sub_job.id), self.export_definition(sub_job), yaml.dump(min_device_cfg), dispatcher_cfg, env_str, env_dut_str]) def start_jobs(self, options): """ Loop on all scheduled jobs and send the START message to the slave. 
""" # make the request atomic query = TestJob.objects.select_for_update() # Only select test job that are ready query = query.filter(state=TestJob.STATE_SCHEDULED) # Only start jobs on online workers query = query.filter(actual_device__worker_host__state=Worker.STATE_ONLINE) # exclude test job without a device: they are special test jobs like # dynamic connection. query = query.exclude(actual_device=None) # TODO: find a way to lock actual_device # Loop on all jobs for job in query: msg = None try: self.start_job(job, options) except jinja2.TemplateNotFound as exc: self.logger.error("[%d] Template not found: '%s'", job.id, exc.message) msg = "Template not found: '%s'" % exc.message except jinja2.TemplateSyntaxError as exc: self.logger.error("[%d] Template syntax error in '%s', line %d: %s", job.id, exc.name, exc.lineno, exc.message) msg = "Template syntax error in '%s', line %d: %s" % (exc.name, exc.lineno, exc.message) except IOError as exc: self.logger.error("[%d] Unable to read '%s': %s", job.id, exc.filename, exc.strerror) msg = "Cannot open '%s': %s" % (exc.filename, exc.strerror) except yaml.YAMLError as exc: self.logger.error("[%d] Unable to parse job definition: %s", job.id, exc) msg = "Cannot parse job definition: %s" % exc if msg: # Add the error as lava.job result metadata = {"case": "job", "definition": "lava", "error_type": "Infrastructure", "error_msg": msg, "result": "fail"} suite, _ = TestSuite.objects.get_or_create(name="lava", job=job) TestCase.objects.create(name="job", suite=suite, result=TestCase.RESULT_FAIL, metadata=yaml.dump(metadata)) job.go_state_finished(TestJob.HEALTH_INCOMPLETE, True) job.save() def cancel_jobs(self, partial=False): query = TestJob.objects.filter(state=TestJob.STATE_CANCELING) if partial: query = query.filter(id__in=list(self.events["canceling"])) for job in query: worker = job.lookup_worker if job.dynamic_connection else job.actual_device.worker_host self.logger.info("[%d] CANCEL => %s", job.id, worker.hostname) 
send_multipart_u(self.controler, [worker.hostname, 'CANCEL', str(job.id)]) def handle(self, *args, **options): # Initialize logging. self.setup_logging("lava-master", options["level"], options["log_file"], FORMAT) self.logger.info("[INIT] Dropping privileges") if not self.drop_privileges(options['user'], options['group']): self.logger.error("[INIT] Unable to drop privileges") return self.logger.info("[INIT] Marking all workers as offline") with transaction.atomic(): for worker in Worker.objects.select_for_update().all(): worker.go_state_offline() worker.save() # Create the sockets context = zmq.Context() self.controler = context.socket(zmq.ROUTER) self.event_socket = context.socket(zmq.SUB) if options['ipv6']: self.logger.info("[INIT] Enabling IPv6") self.controler.setsockopt(zmq.IPV6, 1) self.event_socket.setsockopt(zmq.IPV6, 1) if options['encrypt']: self.logger.info("[INIT] Starting encryption") try: self.auth = ThreadAuthenticator(context) self.auth.start() self.logger.debug("[INIT] Opening master certificate: %s", options['master_cert']) master_public, master_secret = zmq.auth.load_certificate(options['master_cert']) self.logger.debug("[INIT] Using slaves certificates from: %s", options['slaves_certs']) self.auth.configure_curve(domain='*', location=options['slaves_certs']) except IOError as err: self.logger.error(err) self.auth.stop() return self.controler.curve_publickey = master_public self.controler.curve_secretkey = master_secret self.controler.curve_server = True self.logger.debug("[INIT] Watching %s", options["slaves_certs"]) self.inotify_fd = watch_directory(options["slaves_certs"]) if self.inotify_fd is None: self.logger.error("[INIT] Unable to start inotify") self.controler.setsockopt(zmq.IDENTITY, b"master") # From http://api.zeromq.org/4-2:zmq-setsockopt#toc42 # "If two clients use the same identity when connecting to a ROUTER # [...] the ROUTER socket shall hand-over the connection to the new # client and disconnect the existing one." 
self.controler.setsockopt(zmq.ROUTER_HANDOVER, 1) self.controler.bind(options['master_socket']) self.event_socket.setsockopt(zmq.SUBSCRIBE, b(settings.EVENT_TOPIC)) self.event_socket.connect(options['event_url']) # Poll on the sockets. This allow to have a # nice timeout along with polling. self.poller = zmq.Poller() self.poller.register(self.controler, zmq.POLLIN) self.poller.register(self.event_socket, zmq.POLLIN) if self.inotify_fd is not None: self.poller.register(os.fdopen(self.inotify_fd), zmq.POLLIN) # Translate signals into zmq messages (self.pipe_r, _) = self.setup_zmq_signal_handler() self.poller.register(self.pipe_r, zmq.POLLIN) self.logger.info("[INIT] LAVA master has started.") self.logger.info("[INIT] Using protocol version %d", PROTOCOL_VERSION) try: self.main_loop(options) except BaseException as exc: self.logger.error("[CLOSE] Unknown exception raised, leaving!") self.logger.exception(exc) finally: # Drop controler socket: the protocol does handle lost messages self.logger.info("[CLOSE] Closing the controler socket and dropping messages") self.controler.close(linger=0) self.event_socket.close(linger=0) if options['encrypt']: self.auth.stop() context.term() def main_loop(self, options): last_schedule = last_dispatcher_check = time.time() while True: try: try: # Compute the timeout now = time.time() timeout = min(SCHEDULE_INTERVAL - (now - last_schedule), PING_INTERVAL - (now - last_dispatcher_check)) # If some actions are remaining, decrease the timeout if self.events["canceling"]: timeout = min(timeout, 1) # Wait at least for 1ms timeout = max(timeout * 1000, 1) # Wait for data or a timeout sockets = dict(self.poller.poll(timeout)) except zmq.error.ZMQError: continue if sockets.get(self.pipe_r) == zmq.POLLIN: self.logger.info("[POLL] Received a signal, leaving") break # Command socket if sockets.get(self.controler) == zmq.POLLIN: while self.controler_socket(): # Unqueue all pending messages pass # Events socket if sockets.get(self.event_socket) == 
zmq.POLLIN: while self.read_event_socket(): # Unqueue all pending messages pass # Wait for the next iteration to handle the event. # In fact, the code that generated the event (lava-logs or # lava-server-gunicorn) needs some time to commit the # database transaction. # If we are too fast, the database object won't be # available (or in the right state) yet. continue # Inotify socket if sockets.get(self.inotify_fd) == zmq.POLLIN: os.read(self.inotify_fd, 4096) self.logger.debug("[AUTH] Reloading certificates from %s", options['slaves_certs']) self.auth.configure_curve(domain='*', location=options['slaves_certs']) # Check dispatchers status now = time.time() if now - last_dispatcher_check > PING_INTERVAL: for hostname, dispatcher in self.dispatchers.items(): if dispatcher.online and now - dispatcher.last_msg > DISPATCHER_TIMEOUT: if hostname == "lava-logs": self.logger.error("[STATE] lava-logs goes OFFLINE") else: self.logger.error("[STATE] Dispatcher <%s> goes OFFLINE", hostname) self.dispatchers[hostname].go_offline() last_dispatcher_check = now # Limit accesses to the database. This will also limit the rate of # CANCEL and START messages if time.time() - last_schedule > SCHEDULE_INTERVAL: if self.dispatchers["lava-logs"].online: schedule(self.logger) # Dispatch scheduled jobs with transaction.atomic(): self.start_jobs(options) else: self.logger.warning("lava-logs is offline: can't schedule jobs") # Handle canceling jobs self.cancel_jobs() # Do not count the time taken to schedule jobs last_schedule = time.time() else: # Cancel the jobs and remove the jobs from the set if self.events["canceling"]: self.cancel_jobs(partial=True) self.events["canceling"] = set() except (OperationalError, InterfaceError): self.logger.info("[RESET] database connection reset.") # Closing the database connection will force Django to reopen # the connection connection.close() time.sleep(2)
class FrankFancyStreamingInterface(object):
    """
    Abstraction layer to the graph streamer as well as the central logger.

    Two independent outputs are managed:

    * ``self.Active`` -- a plain (non-encrypted) TCP socket to the Active Live
      visualizer (FrankFancyGraphStreamer), fed with JSON / dot-graph data.
    * ``self.Logger`` -- a CURVE-encrypted zeromq PUB socket on which pickled
      ``Event`` objects are published, with ``self.Name`` as the topic frame.

    Either output may be ``None`` (not requested, or the connection failed);
    every public method silently skips the outputs that are unavailable.
    """

    # Maps the cell status passed to ChangeCell to the SubjectId of the Event
    # published on the logger.
    ConvertStatus = {
        "Cells": {
            0: 5,  # removing
            1: 4,  # allocating
            2: 6,  # blacklisting
        }
    }

    # TODO: give every scheduler an unique topic to easily distinguish between
    # them on the queue

    def __init__(self, name, privatekey, VisualizerHost, root_id, ZeromqHost="*", empty=False):
        """
        Open the connections to the Active Live visualizer and the logger.

        :param name: scheduler name; sent as the topic frame of every message
            published on the logger socket
        :type name: str
        :param privatekey: when not None the logger socket is opened (see the
            NOTE(review) in ``_connectLogger``: the value itself is currently unused)
        :param VisualizerHost: IP of the FrankFancyGraphStreamer, or None to
            skip the visualizer connection
        :type VisualizerHost: str
        :param root_id: the root of the network (LBR)
        :type root_id: str
        :param ZeromqHost: interface the zeromq service should bind to
            ("*" for all interfaces); forwarded to ``_connectLogger``
        :type ZeromqHost: str
        :param empty: when True, no connection is opened at all
        :type empty: bool
        """
        self.Active = None
        self.Logger = None
        self.EventId = 0
        self.Name = name  # used as topic on the queue
        if not empty:
            if privatekey is not None:
                self._connectLogger(privatekey, Host=ZeromqHost)
            if VisualizerHost is not None:
                self._connectVisualizer(VisualizerHost, root_id)
        self.g = DoDAG(root_id, root_id)
        self.root_id = root_id

    def _connectVisualizer(self, Host, root_id):
        """
        Connect to the Active Live Visualizer (TCP port 600) and send it the root id.

        On any failure the half-open socket is closed and ``self.Active`` is
        left as None, so the visualizer output is simply disabled.

        :param Host: the ip of the FrankFancyGraphStreamer
        :param root_id: the ip6 address of the root node of the network
        :return:
        """
        sock = None
        try:
            logg.debug("Connecting Streaming Interface to Active Viewer")
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock.connect((Host, 600))
            logg.debug("Sending to Active Viewer:{}".format(root_id))
            sock.sendall(root_id)
        except Exception:
            logg.debug("Connection to Active Viewer failed!", exc_info=True)
            # Do not leak the file descriptor of a socket we are abandoning.
            if sock is not None:
                sock.close()
            self.Active = None
        else:
            self.Active = sock

    def _connectLogger(self, key, Host="localhost"):
        """
        Open a CURVE-encrypted zeromq PUB socket used as the logger.

        Client public keys are read from ``keys/public``; the server
        certificate is loaded from ``keys/plexi1.key_secret``.

        NOTE(review): ``key`` and ``Host`` are currently ignored -- the
        certificate path and the bind address ``tcp://127.0.0.1:6000`` are
        hard-coded. Confirm whether callers rely on that before wiring them up.

        :param key: privatekey file of the scheduler (currently unused)
        :param Host: interface the zeromq service should bind to (currently unused)
        :raises: whatever zmq / certificate loading raises; the authenticator
            thread, socket and context are released first
        :return:
        """
        # TODO: expose more security options such as white/blacklisting ips and domain filtering
        self.context = zmq.Context()
        self.auth = ThreadAuthenticator(self.context)
        self.auth.start()
        try:
            self.auth.configure_curve(domain='*', location=os.path.join("keys", "public"))
            self.Logger = self.context.socket(zmq.PUB)
            scheduler_public, scheduler_secret = zmq.auth.load_certificate(
                os.path.join("keys", "plexi1.key_secret"))
            self.Logger.curve_secretkey = scheduler_secret
            self.Logger.curve_publickey = scheduler_public
            self.Logger.curve_server = True
            self.Logger.bind("tcp://127.0.0.1:6000")
        except Exception:
            # Without this the authenticator thread would keep running forever.
            self.auth.stop()
            if self.Logger is not None:
                self.Logger.close(linger=0)
                self.Logger = None
            self.context.term()
            raise

    def _publishEvent(self, subject_id, info):
        """Pickle an Event (current EventId, subject, timestamp, JSON info) and publish it on the logger."""
        self.Logger.send_multipart([
            self.Name.encode(),
            pickle.dumps(
                Event(self.EventId, subject_id, time.time(), json.dumps(info))),
        ])

    def SendActiveJson(self, data):
        """
        Send an object, JSON-encoded, to the Active Live Viewer.

        :param data: the object to be sent
        :return:
        """
        if self.Active is not None:
            payload = json.dumps(data)
            logg.debug("Sending json data to Active: %s", payload)
            self.Active.sendall(payload)

    def PublishLogging(self, LoggingName="zmq.auth", root_topic="zmq.auth"):
        """
        Forward a python logger's records to the publishing service.

        :param LoggingName: name of the python logger
        :type LoggingName: str
        :param root_topic: topic given with each message; it is appended with .<LEVEL>
        :type root_topic: str
        :return:
        """
        handler = PUBHandler(self.Logger)
        handler.root_topic = root_topic
        formatter = logging.Formatter(
            fmt='%(asctime)s\t%(levelname)s: %(message)s', datefmt='%H:%M:%S')
        handler.formatters[logging.DEBUG] = formatter
        handler.formatters[logging.INFO] = formatter
        logging.getLogger(LoggingName).addHandler(handler)

    def ChangeCell(self, who, slotoffs, channeloffs, frame, ID, status):
        """
        Notify all active outputs about a change to a cell in the schedule matrix.

        :param who: the node in which the cell is changed
        :type who: :class: `node.NodeID`
        :param slotoffs: slot offset
        :param channeloffs: channel offset
        :param frame: frame name
        :param ID: local cell id
        :param status: new status of the cell; must be a key of
            ``ConvertStatus["Cells"]`` when the logger is active (KeyError otherwise)
        :return:
        """
        if self.Active is not None:
            logg.debug("Sending ChangeCell to active viewer")
            self.Active.sendall(
                json.dumps([
                    "changecell", {
                        "who": str(who),
                        "channeloffs": channeloffs,
                        "slotoffs": slotoffs,
                        "frame": frame,
                        "id": ID,
                        "status": status
                    }
                ]))
        if self.Logger is not None:
            self.EventId += 1
            logg.debug("Sending ChangeCell to logger, EventID:%s", self.EventId)
            self._publishEvent(self.ConvertStatus["Cells"][status], {
                "node_id": str(who),
                "channeloffs": channeloffs,
                "slotoffs": slotoffs,
                "frame": frame,
                "id": ID
            })

    def DumpDotData(self, labels=None):
        """
        Dump the whole dot graph to the active viewer (not used for the logger).

        :param labels: optional labels forwarded to ``DoDAG.draw_graph``;
            None means no labels (a fresh empty dict is passed)
        :return:
        """
        if self.Active is not None:
            logg.debug("Sending dotdata")
            dotdata = self.g.draw_graph(labels={} if labels is None else labels)
            header = "\"" + self.root_id + " at " + time.strftime("%Y-%m-%d %H:%M:%S") + "\""
            self.Active.sendall(bytearray(json.dumps([header, dotdata])))
            time.sleep(.5)

    def AddNode(self, node_id, parent):
        """
        Notify the outputs that a node joined the network.

        :param node_id: ip6 of the node
        :type node_id: str
        :param parent: ip6 of the parent node, or "root" to attach it at the top
        :type parent: str
        :return:
        """
        node_id = str(node_id)
        parent = str(parent)
        if self.Logger is not None:
            self.EventId += 1
            logg.debug("Sending Addnode to logger, EventID:%s", self.EventId)
            self._publishEvent(0, {"node_id": node_id, "parent": parent})
        if self.Active is not None:
            logg.debug("Sending Addnode to Active Visualizer, node:{}, parent:{}".format(node_id, parent))
            if parent == "root":
                self.g.attach_node(node_id)
            else:
                self.g.attach_child(node_id, parent)
            self.DumpDotData()

    def RewireNode(self, node_id, old_parent, new_parent):
        """
        Notify the outputs of a rewire that happened in the network.

        :param node_id: ip6 of the node that has rewired
        :param old_parent: ip6 of the old parent
        :param new_parent: ip6 of the new parent
        :return:
        """
        node_id = str(node_id)
        old_parent = str(old_parent)
        new_parent = str(new_parent)
        if self.Logger is not None:
            self.EventId += 1
            logg.debug("Sending RewireNode to logger, EventID: %s", self.EventId)
            self._publishEvent(2, {
                "node_id": node_id,
                "old_parent": old_parent,
                "new_parent": new_parent
            })
        if self.Active is not None:
            logg.debug("Sending Rewire to the Active")
            self.g.attach_child(node_id, new_parent)
            self.DumpDotData()

    def RemoveNode(self, node_id):
        """
        Notify the outputs of a disconnected node.

        :param node_id: ip6 of the node that has disconnected
        :return:
        """
        node_id = str(node_id)
        if self.Logger is not None:
            self.EventId += 1
            logg.debug("Sending RemoveNode to logger, EventID: %s", self.EventId)
            self._publishEvent(1, {"node_id": node_id})
        if self.Active is not None:
            self.g.detach_node(node_id)
            self.DumpDotData()

    def RegisterFrame(self, num_cells, framename):
        """
        Notify the logger of a new frame defined by the scheduler algorithm.

        :param num_cells: number of cells per channel
        :param framename: unique identifying name
        :return:
        """
        if self.Logger is not None:
            self.EventId += 1
            logg.debug("Sending RegisterFrame to logger, EventID: %s", self.EventId)
            self._publishEvent(7, {"cells": num_cells, "name": framename})

    def RegisterFrames(self, frames):
        """Send the JSON-encoded frame definitions to the Active Live Viewer."""
        if self.Active is not None:
            logg.debug("Sending RegisterFrames to Active")
            self.Active.sendall(bytearray(json.dumps(frames)))
class RpcServer:
    """
    RPC server built on two zmq sockets.

    * A REP socket receives pickled ``(name, args, kwargs)`` requests and
      answers each with a pickled ``[True, result]`` or
      ``[False, traceback_text]``.
    * A PUB socket carries ``[topic, data]`` messages sent through
      ``publish`` plus a keep-alive every ``KEEP_ALIVE_INTERVAL``.

    The request loop runs in a background thread started by ``start``.
    """

    def __init__(self):
        """
        Constructor: create the zmq context and sockets. Nothing is bound
        until ``start`` is called.
        """
        # Registered functions: key is the function name, value is the function object
        self.__functions: Dict[str, Any] = {}

        # Zmq port related
        self.__context: zmq.Context = zmq.Context()

        # Reply socket (Request–reply pattern)
        self.__socket_rep: zmq.Socket = self.__context.socket(zmq.REP)

        # Publish socket (Publish–subscribe pattern)
        self.__socket_pub: zmq.Socket = self.__context.socket(zmq.PUB)

        # Worker thread related
        self.__active: bool = False                     # RpcServer status
        self.__thread: threading.Thread = None          # RpcServer thread
        self.__lock: threading.Lock = threading.Lock()  # serializes publish() on the PUB socket

        # Authenticator used to ensure data security
        self.__authenticator: ThreadAuthenticator = None

    def is_active(self) -> bool:
        """Return True while the server loop is (supposed to be) running."""
        return self.__active

    def start(
        self,
        rep_address: str,
        pub_address: str,
        server_secretkey_path: str = "",
        username: str = "",
        password: str = ""
    ) -> None:
        """
        Bind both sockets and start the server thread (no-op if already active).

        Security, in order of precedence:

        * ``server_secretkey_path`` set: CURVE encryption with the server
          certificate loaded from that path; any client public key is
          accepted (``CURVE_ALLOW_ANY``).
        * ``username`` and ``password`` set: PLAIN authentication with that
          single credential pair.
        * otherwise: no authentication.
        """
        if self.__active:
            return

        # Start authenticator
        if server_secretkey_path:
            self.__authenticator = ThreadAuthenticator(self.__context)
            self.__authenticator.start()
            self.__authenticator.configure_curve(
                domain="*",
                location=zmq.auth.CURVE_ALLOW_ANY
            )

            publickey, secretkey = zmq.auth.load_certificate(server_secretkey_path)

            self.__socket_pub.curve_secretkey = secretkey
            self.__socket_pub.curve_publickey = publickey
            self.__socket_pub.curve_server = True

            self.__socket_rep.curve_secretkey = secretkey
            self.__socket_rep.curve_publickey = publickey
            self.__socket_rep.curve_server = True
        elif username and password:
            self.__authenticator = ThreadAuthenticator(self.__context)
            self.__authenticator.start()
            self.__authenticator.configure_plain(
                domain="*",
                passwords={username: password}
            )

            self.__socket_pub.plain_server = True
            self.__socket_rep.plain_server = True

        # Bind socket address
        self.__socket_rep.bind(rep_address)
        self.__socket_pub.bind(pub_address)

        # Start RpcServer status
        self.__active = True

        # Start RpcServer thread
        self.__thread = threading.Thread(target=self.run)
        self.__thread.start()

    def stop(self) -> None:
        """
        Ask the server loop to exit; it notices within about one poll period
        (1 second). Use ``join`` to wait for it.
        """
        if not self.__active:
            return

        # Stop RpcServer status
        self.__active = False

    def join(self) -> None:
        """Wait for the server thread to exit, if it is running."""
        if self.__thread and self.__thread.is_alive():
            self.__thread.join()
        self.__thread = None

    def run(self) -> None:
        """
        Server loop: publish keep-alives and answer RPC requests until stopped.

        On exit both sockets are unbound from their last endpoint and the
        authenticator started by ``start`` (if any) is stopped.
        """
        # Remember which authenticator belongs to this run, so a later start()
        # that creates a new one is not affected by this thread's cleanup.
        authenticator = self.__authenticator
        last_keep_alive = datetime.utcnow()

        while self.__active:
            cur = datetime.utcnow()
            if cur - last_keep_alive >= KEEP_ALIVE_INTERVAL:
                self.publish(KEEP_ALIVE_TOPIC, cur)
                # Restart the interval; without this every later iteration
                # would publish a keep-alive.
                last_keep_alive = cur

            # Use poll to wait event arrival, waiting time is 1 second (1000 milliseconds)
            if not self.__socket_rep.poll(1000):
                continue

            # NOTE(security): recv_pyobj unpickles data received from the
            # network; only expose this server to trusted clients.
            req = self.__socket_rep.recv_pyobj()

            # Unpacking is inside the try so a malformed request is answered
            # with an error instead of killing the server thread (a REP socket
            # must reply before it can receive again).
            try:
                name, args, kwargs = req
                func = self.__functions[name]
                r = func(*args, **kwargs)
                rep = [True, r]
            except Exception:  # noqa
                rep = [False, traceback.format_exc()]

            # send callable response by Reply socket
            self.__socket_rep.send_pyobj(rep)

        # Unbind socket address
        self.__socket_pub.unbind(self.__socket_pub.LAST_ENDPOINT)
        self.__socket_rep.unbind(self.__socket_rep.LAST_ENDPOINT)

        # Stop the authenticator thread, otherwise it outlives the server.
        if authenticator is not None:
            authenticator.stop()
            if self.__authenticator is authenticator:
                self.__authenticator = None

    def publish(self, topic: str, data: Any) -> None:
        """Publish ``[topic, data]`` (pickled) on the PUB socket; thread-safe via the internal lock."""
        with self.__lock:
            self.__socket_pub.send_pyobj([topic, data])

    def register(self, func: Callable) -> None:
        """Register ``func`` so clients can call it by its ``__name__``."""
        self.__functions[func.__name__] = func
def run(self):
    """
    Client thread main loop.

    Keeps one CURVE-encrypted DEALER socket per enabled peer from
    ``taco.globals.settings["Peers"]``. On every pass (woken by ``self.sleep``
    or after 0.1 s) it, for each connected peer:

    * drains the high- and medium-priority output queues,
    * sends from the file-request queue, throttled by the download limiter,
    * sends from the low-priority queue while the upload limiter allows it,
    * sends a rollcall request when ``self.next_rollcall`` is due,
    * reads all pending replies, handing each to
      ``taco.commands.Process_Reply`` and queueing any follow-up request on
      the medium-priority queue,
    * tears the peer down when an error or a communication timeout is
      detected, increasing its reconnect delay up to CLIENT_RECONNECT_MAX.

    Runs until ``self.stop`` is set, then closes all sockets, stops the
    authenticator and terminates the zmq context.

    NOTE(review): written for Python 2 (``Queue.Queue``, and
    ``random.shuffle`` on ``dict.keys()`` needs a list).
    NOTE(review): there is no try/finally; an exception anywhere leaves the
    sockets, authenticator thread and context alive.
    """
    self.set_status("Client Startup")
    self.set_status("Creating zmq Contexts",1)
    clientctx = zmq.Context()
    self.set_status("Starting zmq ThreadedAuthenticator",1)
    #clientauth = zmq.auth.ThreadedAuthenticator(clientctx)
    clientauth = ThreadAuthenticator(clientctx)
    clientauth.start()

    with taco.globals.settings_lock:
        publicdir = os.path.normpath(os.path.abspath(taco.globals.settings["TacoNET Certificates Store"] + "/" + taco.globals.settings["Local UUID"] + "/public/"))
        privatedir = os.path.normpath(os.path.abspath(taco.globals.settings["TacoNET Certificates Store"] + "/" + taco.globals.settings["Local UUID"] + "/private/"))

    self.set_status("Configuring Curve to use publickey dir:" + publicdir)
    clientauth.configure_curve(domain='*', location=publicdir)

    poller = zmq.Poller()
    while not self.stop.is_set():
        #logging.debug("PRE")
        # Sleep up to 0.1 s unless another part of the loop set self.sleep to
        # request an immediate new pass. (`result` is unused.)
        result = self.sleep.wait(0.1)
        #logging.debug(result)
        self.sleep.clear()
        if self.stop.is_set():
            break

        # At most once per second: refresh the rate limits and (re)connect peers.
        if abs(time.time() - self.connect_block_time) > 1:
            with taco.globals.settings_lock:
                self.max_upload_rate = taco.globals.settings["Upload Limit"] * taco.constants.KB
            with taco.globals.settings_lock:
                self.max_download_rate = taco.globals.settings["Download Limit"] * taco.constants.KB
            # NOTE(review): raises ZeroDivisionError if "Download Limit" is 0.
            self.chunk_request_rate = float(taco.constants.FILESYSTEM_CHUNK_SIZE) / float(self.max_download_rate)
            #logging.debug(str((self.max_download_rate,taco.constants.FILESYSTEM_CHUNK_SIZE,self.chunk_request_rate)))
            self.connect_block_time = time.time()
            with taco.globals.settings_lock:
                for peer_uuid in taco.globals.settings["Peers"].keys():
                    if taco.globals.settings["Peers"][peer_uuid]["enabled"]:
                        #init some defaults
                        if not peer_uuid in self.client_reconnect_mod:
                            self.client_reconnect_mod[peer_uuid] = taco.constants.CLIENT_RECONNECT_MIN
                        if not peer_uuid in self.client_connect_time:
                            self.client_connect_time[peer_uuid] = time.time() + self.client_reconnect_mod[peer_uuid]
                        if not peer_uuid in self.client_timeout:
                            self.client_timeout[peer_uuid] = time.time() + taco.constants.ROLLCALL_TIMEOUT

                        # Connect once the (back-off) connect time has been reached.
                        if time.time() >= self.client_connect_time[peer_uuid]:
                            if peer_uuid not in self.clients.keys():
                                self.set_status("Starting Client for: " + peer_uuid)
                                try:
                                    ip_of_client = socket.gethostbyname(taco.globals.settings["Peers"][peer_uuid]["hostname"])
                                except:
                                    # NOTE(review): bare except also swallows KeyError/KeyboardInterrupt.
                                    self.set_status("Starting of client failed due to bad dns lookup:" + peer_uuid)
                                    continue
                                self.clients[peer_uuid] = clientctx.socket(zmq.DEALER)
                                self.clients[peer_uuid].setsockopt(zmq.LINGER, 0)
                                client_public, client_secret = zmq.auth.load_certificate(os.path.normpath(os.path.abspath(privatedir + "/" + taco.constants.KEY_GENERATION_PREFIX +"-client.key_secret")))
                                self.clients[peer_uuid].curve_secretkey = client_secret
                                self.clients[peer_uuid].curve_publickey = client_public
                                self.clients[peer_uuid].curve_serverkey = str(taco.globals.settings["Peers"][peer_uuid]["serverkey"])
                                self.clients[peer_uuid].connect("tcp://" + ip_of_client + ":" + str(taco.globals.settings["Peers"][peer_uuid]["port"]))
                                self.next_rollcall[peer_uuid] = time.time()
                                # Fresh per-peer output queues, one per priority class.
                                with taco.globals.high_priority_output_queue_lock:
                                    taco.globals.high_priority_output_queue[peer_uuid] = Queue.Queue()
                                with taco.globals.medium_priority_output_queue_lock:
                                    taco.globals.medium_priority_output_queue[peer_uuid] = Queue.Queue()
                                with taco.globals.low_priority_output_queue_lock:
                                    taco.globals.low_priority_output_queue[peer_uuid] = Queue.Queue()
                                with taco.globals.file_request_output_queue_lock:
                                    taco.globals.file_request_output_queue[peer_uuid] = Queue.Queue()
                                poller.register(self.clients[peer_uuid],zmq.POLLIN)

        if len(self.clients.keys()) == 0:
            continue

        # Visit peers in random order so none is systematically served first.
        peer_keys = self.clients.keys()
        random.shuffle(peer_keys)
        for peer_uuid in peer_keys:
            #self.set_status("Socket Write Possible:" + peer_uuid)

            #high priority queue processing
            with taco.globals.high_priority_output_queue_lock:
                while not taco.globals.high_priority_output_queue[peer_uuid].empty():
                    self.set_status("high priority output q not empty:" + peer_uuid)
                    data = taco.globals.high_priority_output_queue[peer_uuid].get()
                    # Empty first frame: envelope delimiter, presumably for the REP peer.
                    self.clients[peer_uuid].send_multipart(['',data])
                    self.sleep.set()
                    with taco.globals.upload_limiter_lock:
                        taco.globals.upload_limiter.add(len(data))

            #medium priority queue processing
            with taco.globals.medium_priority_output_queue_lock:
                while not taco.globals.medium_priority_output_queue[peer_uuid].empty():
                    self.set_status("medium priority output q not empty:" + peer_uuid)
                    data = taco.globals.medium_priority_output_queue[peer_uuid].get()
                    self.clients[peer_uuid].send_multipart(['',data])
                    self.sleep.set()
                    with taco.globals.upload_limiter_lock:
                        taco.globals.upload_limiter.add(len(data))

            #filereq q, aka the download throttle
            if time.time() >= self.file_request_time:
                self.file_request_time = time.time()
                with taco.globals.file_request_output_queue_lock:
                    if not taco.globals.file_request_output_queue[peer_uuid].empty():
                        with taco.globals.download_limiter_lock:
                            download_rate = taco.globals.download_limiter.get_rate()
                        # Delay the next file request in proportion to how much of
                        # the download budget is already in use.
                        bw_percent = download_rate / self.max_download_rate
                        wait_time = self.chunk_request_rate * bw_percent
                        #self.set_status(str((download_rate,self.max_download_rate,self.chunk_request_rate,bw_percent,wait_time)))
                        if wait_time > 0.01:
                            self.file_request_time += wait_time
                        if download_rate < self.max_download_rate:
                            self.set_status("filereq output q not empty+free bw:" + peer_uuid)
                            data = taco.globals.file_request_output_queue[peer_uuid].get()
                            self.clients[peer_uuid].send_multipart(['',data])
                            self.sleep.set()
                            with taco.globals.upload_limiter_lock:
                                taco.globals.upload_limiter.add(len(data))

            #low priority queue processing
            with taco.globals.low_priority_output_queue_lock:
                if not taco.globals.low_priority_output_queue[peer_uuid].empty():
                    with taco.globals.upload_limiter_lock:
                        upload_rate = taco.globals.upload_limiter.get_rate()
                    if upload_rate < self.max_upload_rate:
                        self.set_status("low priority output q not empty+free bw:" + peer_uuid)
                        data = taco.globals.low_priority_output_queue[peer_uuid].get()
                        self.clients[peer_uuid].send_multipart(['',data])
                        self.sleep.set()
                        with taco.globals.upload_limiter_lock:
                            taco.globals.upload_limiter.add(len(data))

            #rollcall special case
            if self.next_rollcall[peer_uuid] < time.time():
                #self.set_status("Requesting Rollcall from: " + peer_uuid)
                data = taco.commands.Request_Rollcall()
                self.clients[peer_uuid].send_multipart(['',data])
                with taco.globals.upload_limiter_lock:
                    taco.globals.upload_limiter.add(len(data))
                # Randomized interval until the next rollcall for this peer.
                self.next_rollcall[peer_uuid] = time.time() + random.randint(taco.constants.ROLLCALL_MIN,taco.constants.ROLLCALL_MAX)
                self.sleep.set()
                #continue

            #RECEIVE BLOCK
            # Non-blocking poll; read replies until this peer has nothing pending.
            socks = dict(poller.poll(0))
            while self.clients[peer_uuid] in socks and socks[self.clients[peer_uuid]] == zmq.POLLIN:
                #self.set_status("Socket Read Possible")
                sink,data = self.clients[peer_uuid].recv_multipart()
                with taco.globals.download_limiter_lock:
                    taco.globals.download_limiter.add(len(data))
                self.set_client_last_reply(peer_uuid)
                self.next_request = taco.commands.Process_Reply(peer_uuid,data)
                if self.next_request != "":
                    with taco.globals.medium_priority_output_queue_lock:
                        taco.globals.medium_priority_output_queue[peer_uuid].put(self.next_request)
                    self.sleep.set()
                socks = dict(poller.poll(0))

            #cleanup block
            self.error_msg = []
            # NOTE(review): sockets are registered with POLLIN only, so this
            # POLLERR comparison is likely never true.
            if self.clients[peer_uuid] in socks and socks[self.clients[peer_uuid]] == zmq.POLLERR:
                self.error_msg.append("got a socket error")
            # NOTE(review): presumably set_client_last_reply refreshes
            # client_timeout; confirm, otherwise peers time out on a fixed schedule.
            if abs(self.client_timeout[peer_uuid] - time.time()) > taco.constants.ROLLCALL_TIMEOUT:
                self.error_msg.append("havn't seen communications")

            if len(self.error_msg) > 0:
                self.set_status("Stopping client: " + peer_uuid + " -- " + " and ".join(self.error_msg),2)
                poller.unregister(self.clients[peer_uuid])
                self.clients[peer_uuid].close(0)
                del self.clients[peer_uuid]
                del self.client_timeout[peer_uuid]
                # NOTE(review): self.next_rollcall[peer_uuid] is not removed here.
                with taco.globals.high_priority_output_queue_lock:
                    del taco.globals.high_priority_output_queue[peer_uuid]
                with taco.globals.medium_priority_output_queue_lock:
                    del taco.globals.medium_priority_output_queue[peer_uuid]
                with taco.globals.low_priority_output_queue_lock:
                    del taco.globals.low_priority_output_queue[peer_uuid]
                with taco.globals.file_request_output_queue_lock:
                    del taco.globals.file_request_output_queue[peer_uuid]
                # Back off: grow the reconnect delay, capped at CLIENT_RECONNECT_MAX.
                self.client_reconnect_mod[peer_uuid] = min(self.client_reconnect_mod[peer_uuid] + taco.constants.CLIENT_RECONNECT_MOD,taco.constants.CLIENT_RECONNECT_MAX)
                self.client_connect_time[peer_uuid] = time.time() + self.client_reconnect_mod[peer_uuid]

    self.set_status("Terminating Clients")
    for peer_uuid in self.clients.keys():
        self.clients[peer_uuid].close(0)
    self.set_status("Stopping zmq ThreadedAuthenticator")
    clientauth.stop()
    clientctx.term()
    self.set_status("Clients Exit")
def run(self):
    """
    Server thread main loop.

    Binds a CURVE-encrypted REP socket on the configured "Application IP" /
    "Application Port". Only clients whose public keys are in
    ``<cert store>/<local uuid>/public/`` are accepted. Each request goes to
    ``taco.commands.Proccess_Request`` and its reply is sent back
    immediately. The loop runs until ``self.stop`` is set.

    The socket, the authenticator thread and the zmq context are released
    even if request processing raises; the exception then propagates.
    """
    self.set_status("Server Startup")
    self.set_status("Creating zmq Contexts",1)
    serverctx = zmq.Context()
    self.set_status("Starting zmq ThreadedAuthenticator",1)
    serverauth = ThreadAuthenticator(serverctx)
    serverauth.start()

    server = None
    try:
        with taco.globals.settings_lock:
            bindip = taco.globals.settings["Application IP"]
            bindport = taco.globals.settings["Application Port"]
            certroot = taco.globals.settings["TacoNET Certificates Store"] + "/" + taco.globals.settings["Local UUID"]
            publicdir = os.path.normpath(os.path.abspath(certroot + "/public/"))
            privatedir = os.path.normpath(os.path.abspath(certroot + "/private/"))

        self.set_status("Configuring Curve to use publickey dir:" + publicdir)
        serverauth.configure_curve(domain='*', location=publicdir)

        self.set_status("Creating Server Context",1)
        server = serverctx.socket(zmq.REP)
        server.setsockopt(zmq.LINGER, 0)

        self.set_status("Loading Server Certs",1)
        server_public, server_secret = zmq.auth.load_certificate(
            os.path.normpath(os.path.abspath(privatedir + "/" + taco.constants.KEY_GENERATION_PREFIX + "-server.key_secret")))
        server.curve_secretkey = server_secret
        server.curve_publickey = server_public
        server.curve_server = True

        # zmq spells "all interfaces" as "*", not "0.0.0.0".
        if bindip == "0.0.0.0":
            bindip = "*"
        endpoint = "tcp://" + bindip + ":" + str(bindport)
        self.set_status("Server is now listening for encrypted ZMQ connections @ " + endpoint)
        server.bind(endpoint)

        poller = zmq.Poller()
        poller.register(server, zmq.POLLIN)

        while not self.stop.is_set():
            socks = dict(poller.poll(200))
            if not socks.get(server, 0) & zmq.POLLIN:
                continue

            data = server.recv()
            with taco.globals.download_limiter_lock:
                taco.globals.download_limiter.add(len(data))
            (client_uuid, reply) = taco.commands.Proccess_Request(data)
            if client_uuid != "0":
                self.set_client_last_request(client_uuid)

            # A REP socket must answer every request before it can receive
            # again. Reply right away instead of gating the send on a second
            # short POLLOUT poll: if that poll ever missed, the reply was
            # skipped and the socket could never receive again.
            with taco.globals.upload_limiter_lock:
                taco.globals.upload_limiter.add(len(reply))
            server.send(reply)
    finally:
        if server is not None:
            self.set_status("Stopping zmq server with 0 second linger")
            server.close(0)
        self.set_status("Stopping zmq ThreadedAuthenticator")
        serverauth.stop()
        serverctx.term()
        self.set_status("Server Exit")
class Facilities(furnace_runtime.FurnaceRuntime):
    """
    Main Facilities class.

    Bridges the APP partition (frontend, over IPC ROUTER/PUB sockets) and the
    backend proxy (over CURVE-encrypted TCP DEALER/REQ/SUB sockets). It also
    provides a small, quota-limited key/value store backed by files in
    ``FAC_IO_DIR``.
    """

    def __init__(self, debug, u2, kp, qt, it, ip, fac):
        """
        Constructor, ZMQ connections to the APP partition and backend are built here.

        :param debug: Debug flag, forwarded to the FurnaceRuntime base class.
        :param u2: The UUID for the corresponding app instance.
        :param kp: The keypair paths, in the form {'be_pub': '[path]', 'app_key': '[path]'}
        :param qt: Custom quotas dict; ``qt["_furnace_max_disk"]`` is the storage
            quota used by ``process_set``.
        :param it: Connect to this proxy TCP port (the PUB socket is at ``it + 1``).
        :param ip: Connect to this proxy IP.
        :param fac: The path to the FAC partition's ZMQ socket directory.
        :raises Exception: if protobuf is not running in cpp accelerated mode.
        """
        super(Facilities, self).__init__(debug=debug)
        self.debug = debug
        self.u2 = u2
        self.fac_path = fac
        self.it = it
        self.ip = ip
        self.already_shutdown = False
        self.tprint(DEBUG, "starting Facilities constructor")
        self.tprint(DEBUG, f"running as UUID={u2}")

        api = api_implementation._default_implementation_type
        if api == "python":
            raise Exception("not running in cpp accelerated mode")
        self.tprint(DEBUG, f"running in api mode: {api}")

        self.fmsg_in = facilities_pb2.FacMessage()
        self.fmsg_out = facilities_pb2.FacMessage()
        self.pmsg_in = facilities_pb2.FacMessage()
        self.pmsg_out = facilities_pb2.FacMessage()

        # key -> numeric file id inside FAC_IO_DIR
        self.storage = {}
        self.io_id = 0
        self.disk_usage = 0  # bytes currently stored
        self.max_disk_usage = int(qt["_furnace_max_disk"])
        self.tprint(DEBUG, f"MAX_DISK_USAGE {self.max_disk_usage}")

        self.context = zmq.Context(io_threads=4)
        self.poller = zmq.Poller()

        # connections with the proxy
        self.tprint(DEBUG, "loading crypto...")
        self.auth = ThreadAuthenticator(self.context)  # crypto bootstrap
        self.auth.start()
        self.auth.configure_curve(domain="*", location=zmq.auth.CURVE_ALLOW_ANY)
        publisher_public, _ = zmq.auth.load_certificate(kp["be_pub"])
        public, secret = zmq.auth.load_certificate(kp["app_key"])
        assert self.auth.is_alive()
        self.tprint(DEBUG, "loading crypto... DONE")

        TCP_PXY_RTR = f"tcp://{str(ip)}:{int(it)+0}"
        TCP_PXY_PUB = f"tcp://{str(ip)}:{int(it)+1}"

        # connections to the proxy
        self.dealer_pxy = self.context.socket(zmq.DEALER)  # async directed
        self.didentity = f"{u2}-d"  # this must match the FE
        self.dealer_pxy.identity = self.didentity.encode()
        self.dealer_pxy.curve_publickey = public
        self.dealer_pxy.curve_secretkey = secret
        self.dealer_pxy.curve_serverkey = publisher_public
        self.tprint(INFO, f"FAC--PXY: Connecting as Dealer to {TCP_PXY_RTR}")
        self.dealer_pxy.connect(TCP_PXY_RTR)

        self.req_pxy = self.context.socket(zmq.REQ)  # sync
        self.ridentity = f"{u2}-r"
        self.req_pxy.identity = self.ridentity.encode()
        self.req_pxy.curve_publickey = public
        self.req_pxy.curve_secretkey = secret
        self.req_pxy.curve_serverkey = publisher_public
        self.req_pxy.connect(TCP_PXY_RTR)

        self.sub_pxy = self.context.socket(zmq.SUB)  # async broadcast
        self.sub_pxy.curve_publickey = public
        self.sub_pxy.curve_secretkey = secret
        self.sub_pxy.curve_serverkey = publisher_public
        self.sub_pxy.setsockopt(zmq.SUBSCRIBE, b"")
        self.tprint(INFO, f"FAC--PXY: Connecting as Publisher to {TCP_PXY_PUB}")
        self.sub_pxy.connect(TCP_PXY_PUB)

        self.poller.register(self.dealer_pxy, zmq.POLLIN)
        self.poller.register(self.sub_pxy, zmq.POLLIN)

        # connections with the frontend
        s = f"ipc://{self.fac_path}{os.sep}fac_rtr"
        self.rtr_fac = self.context.socket(zmq.ROUTER)  # fe sync/async msgs
        self.tprint(INFO, f"APP--FAC: Binding as Router to {s}")
        try:
            self.rtr_fac.bind(s)
        except zmq.error.ZMQError:
            self.tprint(ERROR, f"rtr_fac {s} is already in use!")
            sys.exit(1)  # non-zero: this is a startup failure

        s = f"ipc://{self.fac_path}{os.sep}fac_pub"
        self.pub_fac = self.context.socket(zmq.PUB)  # bcast to fe
        self.tprint(INFO, f"APP--FAC: Binding as Publisher to {s}")
        try:
            self.pub_fac.bind(s)
        except zmq.error.ZMQError:
            self.tprint(ERROR, f"pub_fac {s} is already in use!")
            sys.exit(1)

        self.poller.register(self.rtr_fac, zmq.POLLIN)
        self.tprint(DEBUG, "leaving Facilities constructor")

    def loop(self):
        """
        Main event loop.

        Polls (with timeout) on ZMQ sockets waiting for data from the APP
        partition and backend.

        Counters ``self.msgin``, ``self.msgout`` and ``self.tick`` are not set
        in this class; presumably the FurnaceRuntime base class initializes
        them (TODO confirm).

        :returns: Nothing (loops forever).
        """
        while True:
            socks = dict(self.poller.poll(TIMEOUT_FAC))

            # inbound async from BE (sync should not arrive here)
            if self.dealer_pxy in socks:
                sys.stdout.write("v")
                raw_msg = self.dealer_pxy.recv()
                self.msgin += 1
                self.be2fe_process_protobuf(raw_msg, ASYNC)

            # inbound broadcast from BE: re-wrap the first be_msg and publish to FE
            if self.sub_pxy in socks:
                sys.stdout.write("v")
                raw_msg = self.sub_pxy.recv()
                self.msgin += 1
                self.pmsg_in.ParseFromString(raw_msg)
                self.tprint(DEBUG, "BROADCAST FROM PXY\n%s" % (self.pmsg_in, ))
                m = self.pmsg_in.be_msg[0]
                self.pmsg_out.Clear()
                self.pmsg_out.type.append(getattr(self.pmsg_out, "BE_MSG"))
                submsg = self.pmsg_out.be_msg.add()
                submsg.status = m.status
                submsg.value = m.value
                raw_str = self.pmsg_out.SerializeToString()
                self.tprint(DEBUG, f"to FE {len(raw_str)}B {self.pmsg_out}")
                self.pub_fac.send(raw_str)
                self.msgout += 1

            # inbound from FE
            if self.rtr_fac in socks:
                sys.stdout.write("^")
                pkt = self.rtr_fac.recv_multipart()
                self.msgin += 1
                if len(pkt) == 3:  # sync: [identity, empty delimiter, payload]
                    ident, _empty, raw_msg = pkt
                    raw_out = self.fe2be_process_protobuf(raw_msg, SYNC)
                    # ident and raw_out are already bytes; calling .encode()
                    # on them raises AttributeError on Python 3.
                    self.rtr_fac.send_multipart([ident, b"", raw_out])
                    self.msgout += 1
                elif len(pkt) == 2:  # async: [identity, payload]
                    _ident, raw_msg = pkt
                    self.fe2be_process_protobuf(raw_msg, ASYNC)

            if not socks:
                sys.stdout.write(".")
            sys.stdout.flush()
            self.tick += 1

    def be2fe_process_protobuf(self, raw_msg, sync):
        """
        Processes messages sent by the backend.

        Each BE_MSG entry is re-wrapped and sent to the frontend DEALER whose
        identity is ``self.didentity``.

        :param raw_msg: Serialized FacMessage from the backend.
        :param sync: SYNC/ASYNC marker (currently unused).
        :returns: Nothing (modifies class's protobufs).
        """
        self.pmsg_in.ParseFromString(raw_msg)
        self.tprint(DEBUG, "FROM BE\n%s" % (self.pmsg_in, ))
        for t in self.pmsg_in.type:
            msgtype = facilities_pb2.FacMessage.Type.Name(t)  # e.g. BE_MSG
            msglist = getattr(self.pmsg_in, msgtype.lower())
            m = msglist.pop(0)  # the particular message
            if msgtype == "BE_MSG":
                self.fmsg_out.Clear()
                self.fmsg_out.type.append(getattr(self.fmsg_out, "BE_MSG"))
                submsg = self.fmsg_out.be_msg.add()
                submsg.status = m.status
                submsg.value = m.value
                raw_str = self.fmsg_out.SerializeToString()
                self.tprint(
                    DEBUG, f"to FE {self.didentity} {len(raw_str)}B {self.fmsg_out}")
                self.rtr_fac.send_multipart([self.didentity.encode(), raw_str])
                self.msgout += 1
        return

    # Split off between pass-through to BE or fac-provided IO.
    def fe2be_process_protobuf(self, raw_msg, sync):
        """
        Processes messages sent by the app.

        Some messages are processed locally (IO). Others are forwarded to the
        backend.

        :param raw_msg: Serialized FacMessage from the app.
        :param sync: SYNC or ASYNC, forwarded to ``fe2be_process_msg``.
        :returns: Raw serialized protobuf (bytes) holding the replies.
        """
        self.fmsg_in.ParseFromString(raw_msg)
        self.tprint(DEBUG, "FROM FE\n%s" % (self.fmsg_in, ))
        self.fmsg_out.Clear()
        for t in self.fmsg_in.type:
            msgtype = facilities_pb2.FacMessage.Type.Name(t)  # e.g. BE_MSG
            msglist = getattr(self.fmsg_in, msgtype.lower())
            m = msglist.pop(0)  # the particular message
            if msgtype == "BE_MSG":  # Pass-through to BE
                self.fe2be_process_msg(m, sync)
            elif msgtype == "GET":  # IO
                self.process_get(m)
            elif msgtype == "SET":  # IO
                self.process_set(m)
        self.tprint(DEBUG, f"sending {self.fmsg_out}")
        raw_out = self.fmsg_out.SerializeToString()
        return raw_out

    def fe2be_process_msg(self, m, sync):
        """
        Processes messages sent by the app. All these are destined to the backend.

        For SYNC the request goes over the REQ socket and the backend's
        first be_msg_ret is copied into ``self.fmsg_out``. For ASYNC it is sent
        over the DEALER socket and no reply is awaited.

        :param m: The BE_MSG sub-message from the app.
        :param sync: SYNC or ASYNC.
        :returns: Nothing (modifies class's protobufs).
        """
        self.pmsg_out.Clear()
        self.pmsg_out.type.append(getattr(self.pmsg_out, "BE_MSG"))
        submsg = self.pmsg_out.be_msg.add()
        submsg.status = m.status
        submsg.value = m.value
        raw_str = self.pmsg_out.SerializeToString()
        self.tprint(DEBUG, "sending to BE %dB %s" % (len(raw_str), self.pmsg_out))
        if sync == SYNC:
            self.req_pxy.send(raw_str)
            self.msgout += 1
            raw_in = self.req_pxy.recv()
            self.pmsg_in.ParseFromString(raw_in)
            self.msgin += 1
            self.tprint(DEBUG, "recv from BE %dB %s" % (len(raw_in), self.pmsg_in))
            ret = self.pmsg_in.be_msg_ret[0]
            status = ret.status
            val = ret.value
            self.fmsg_out.Clear()
            self.fmsg_out.type.append(getattr(self.fmsg_out, "BE_MSG_RET"))
            submsg = self.fmsg_out.be_msg_ret.add()
            submsg.status = status
            submsg.value = val
            return
        elif sync == ASYNC:
            self.dealer_pxy.send(raw_str)
            self.msgout += 1
            return

    def process_get(self, m):
        """
        The app is asking for previously stored data.

        :param m: The protobuf message sent by the app.
        :returns: Nothing (appends a GET_RET to ``self.fmsg_out``; result is
            VMI_FAILURE for an invalid or unknown key, or an unreadable file).
        """
        key = str(m.key)
        msgtype = "GET_RET"
        self.fmsg_out.type.append(getattr(self.fmsg_out, msgtype))
        submsg = getattr(self.fmsg_out, msgtype.lower()).add()
        submsg.key = key
        if not self.get_name_ok(key):
            submsg.value = ""
            submsg.result = VMI_FAILURE
            return
        try:
            io_id = self.storage[key]
            with open(FAC_IO_DIR + str(io_id), "rb") as fh:
                submsg.value = fh.read().decode()
            submsg.result = VMI_SUCCESS
        except (KeyError, OSError, UnicodeDecodeError):
            # storage is a dict, so an unknown key raises KeyError.
            submsg.value = ""
            submsg.result = VMI_FAILURE
        return

    def process_set(self, m):
        """
        The app has requested we store this data.

        The value is stored UTF-8 encoded. Sizes are accounted in bytes against
        ``self.max_disk_usage`` (from the ``qt`` quotas dict). An existing
        value for the same key is replaced.

        :param m: The protobuf message sent by the app.
        :returns: Nothing (appends a SET_RET to ``self.fmsg_out``).
        """
        key = str(m.key)
        value = str(m.value)
        data = value.encode()
        msgtype = "SET_RET"
        self.fmsg_out.type.append(getattr(self.fmsg_out, msgtype))
        submsg = getattr(self.fmsg_out, msgtype.lower()).add()
        submsg.key = key
        if not self.get_name_ok(key):
            self.tprint(ERROR, "key not allowed")
            submsg.value = ""
            submsg.result = VMI_FAILURE
            return
        self.tprint(
            DEBUG,
            f"cur={self.disk_usage}, req={len(data)}, max={self.max_disk_usage}",
        )
        if self.disk_usage + len(data) > self.max_disk_usage:
            self.tprint(ERROR, "disk usage exceeded")
            submsg.result = VMI_FAILURE
            return
        try:
            self.tprint(DEBUG, f"checking for existing key {key}")
            # overwrite a key
            existing_id = self.storage.get(key)
            if existing_id is None:
                self.tprint(DEBUG, "not found")
            else:
                self.tprint(DEBUG, "found!")
                existing_path = FAC_IO_DIR + str(existing_id)
                s = os.path.getsize(existing_path)
                self.tprint(DEBUG, f"deleting existing file of size={s}")
                os.remove(existing_path)
                self.disk_usage -= s
                del self.storage[key]
            io_id = self.io_id
            self.io_id += 1
            self.tprint(DEBUG, f"will assign key {io_id}")
            self.tprint(DEBUG, f"writing file of size {len(data)}")
            with open(FAC_IO_DIR + str(io_id), "wb") as fh:
                fh.write(data)
            # Account only after the write succeeded, and in bytes so it
            # matches the os.path.getsize() subtraction above.
            self.disk_usage += len(data)
            self.storage[key] = io_id
            self.tprint(DEBUG, f"disk usage at {self.disk_usage}")
            submsg.result = VMI_SUCCESS
        except Exception as e:
            self.tprint(ERROR, f"set failed: {e!r}")
            submsg.result = VMI_FAILURE
        return

    def get_name_ok(self, n):
        """
        :param n: The key: 1-16 characters, each a word character (``\\w``,
            i.e. letters, digits, underscore) or ``-``.
        :returns: True if in accordance with naming convention, False if otherwise.
        """
        # fullmatch: a trailing newline must not be accepted (``$`` would).
        return 0 < len(n) < 17 and re.fullmatch(r"[\w-]+", n) is not None

    def check_encoding(self, val):
        """
        Fun. Now with more Unicode.

        :param val: Is this variable a string?
        :returns: Nothing.
        :raises: TypeError if val is not a string.
        """
        if not isinstance(val, str):
            raise TypeError("Error: value is not of type str")

    def shutdown(self):
        """
        Exit cleanly: remove stored files, stop the authenticator, close all
        sockets and destroy the context. Safe to call more than once.

        :returns: Nothing.
        """
        if self.already_shutdown:
            self.tprint(ERROR, "already shutdown?")
            return
        self.already_shutdown = True
        for key, io_id in self.storage.items():
            self.tprint(INFO, "removing io: %s -> %s" % (key, io_id))
            os.remove(FAC_IO_DIR + str(io_id))
        self.auth.stop()
        self.poller.unregister(self.dealer_pxy)
        self.poller.unregister(self.sub_pxy)
        self.dealer_pxy.close()
        self.sub_pxy.close()
        self.req_pxy.close()  # never registered with the poller
        self.poller.unregister(self.rtr_fac)
        self.rtr_fac.close()
        self.pub_fac.close()
        self.context.destroy()
        super(Facilities, self).shutdown()
class Actor(object):
    '''The actor class implements all the management and control functions over its components

    :param gModel: the JSON-based dictionary holding the model for the app this actor belongs to.
    :type gModel: dict
    :param gModelName: the name of the top-level model for the app
    :type gModelName: str
    :param aName: name of the actor. It is an index into the gModel that points to the part of the model specific to the actor
    :type aName: str
    :param sysArgv: list of arguments for the actor: -key1 value1 -key2 value2 ...
    :type sysArgv: list
    '''

    def __init__(self, gModel, gModelName, aName, sysArgv):
        '''
        Constructor: resolve network identity, parse parameters, set up ZMQ
        contexts (and CURVE authentication when Config.SECURITY is on), and
        build the component instances described by the model.
        '''
        self.logger = logging.getLogger(__name__)
        self.inst_ = self
        self.appName = gModel["name"]
        self.modelName = gModelName
        self.name = aName
        self.pid = os.getpid()
        self.uuid = None
        self.setupIfaces()
        # Assumption : pid is a 4 byte int
        self.actorID = ipaddress.IPv4Address(
            self.globalHost).packed + self.pid.to_bytes(4, 'big')
        self.suffix = ""
        if aName not in gModel["actors"]:
            raise BuildError('Actor "%s" unknown' % aName)
        self.model = gModel["actors"][aName]  # Fetch the relevant content from the model

        self.INT_RE = re.compile(r"^[-]?\d+$")
        self.parseParams(sysArgv)

        # Use czmq's context
        czmq_ctx = Zsys.init()
        self.context = zmq.Context.shadow(czmq_ctx.value)
        Zsys.handler_reset()  # Reset previous signal handler

        if Config.SECURITY:
            # Separate context for app sockets, guarded by a CURVE authenticator.
            # (Created only here: otherwise it was allocated and then leaked.)
            self.appContext = zmq.Context()
            (self.public_key, self.private_key) = zmq.auth.load_certificate(const.appCertFile)
            hosts = ['127.0.0.1']
            try:
                with open(const.appDescFile, 'r') as f:
                    # NOTE(review): yaml.load constructs arbitrary objects and
                    # `content.hosts` presumes the file deserializes to an object
                    # with a `hosts` attribute - confirm format before switching
                    # to safe_load.
                    content = yaml.load(f)
                    hosts += content.hosts
            except Exception:
                pass  # best effort: fall back to localhost only
            self.auth = ThreadAuthenticator(self.appContext)
            self.auth.start()
            self.auth.allow(*hosts)
            self.auth.configure_curve(domain='*', location=zmq.auth.CURVE_ALLOW_ANY)
        else:
            (self.public_key, self.private_key) = (None, None)
            self.auth = None
            self.appContext = self.context

        try:
            if os.path.isfile(const.logConfFile) and os.access(const.logConfFile, os.R_OK):
                spdlog_setup.from_file(const.logConfFile)
        except Exception as e:
            self.logger.error("error while configuring componentLogger: %s" % repr(e))

        # Global message types (global on the network)
        self.messageNames = [spec["name"] for spec in gModel["messages"]]
        # Local message types (local to the host)
        self.localNames = [spec["type"] for spec in self.model["locals"]]
        # Internal message types (internal to the actor process)
        self.internalNames = [spec["type"] for spec in self.model["internals"]]

        self.components = {}
        instSpecs = self.model["instances"]
        compSpecs = gModel["components"]
        ioSpecs = gModel["devices"]
        for instName in instSpecs:  # Create the component instances: the 'parts'
            instSpec = instSpecs[instName]
            instType = instSpec['type']
            if instType in compSpecs:
                typeSpec = compSpecs[instType]
                ioComp = False
            elif instType in ioSpecs:
                typeSpec = ioSpecs[instType]
                ioComp = True
            else:
                raise BuildError(
                    'Component type "%s" for instance "%s" is undefined' % (instType, instName))
            instFormals = typeSpec['formals']
            instActuals = instSpec['actuals']
            instArgs = self.buildInstArgs(instName, instFormals, instActuals)

            # Check whether the component is C++ component
            ccComponentFile = 'lib' + instType.lower() + '.so'
            ccComp = os.path.isfile(ccComponentFile)
            try:
                if not ioComp:
                    if ccComp:
                        modObj = importlib.import_module('lib' + instType.lower())
                        self.components[instName] = modObj.create_component_py(
                            self, self.model, typeSpec, instName, instType,
                            instArgs, self.appName, self.name)
                    else:
                        self.components[instName] = Part(
                            self, typeSpec, instName, instType, instArgs)
                else:
                    self.components[instName] = Peripheral(
                        self, typeSpec, instName, instType, instArgs)
            except Exception as e:
                # Deliberate best effort: one broken part does not stop the actor.
                traceback.print_exc()
                self.logger.error("Error while constructing part '%s.%s': %s" %
                                  (instType, instName, str(e)))

    def getParameterValueType(self, param, defaultType):
        ''' Infer the type of a parameter from its value unless a default type is provided. \
            In the latter case the parameter's value is converted to that type.

            :param param: a parameter value (string from the command line)
            :type param: str
            :param defaultType: type of the parameter's default value, or None
            :type defaultType: one of bool,int,float,str or None
            :return: a pair (value,type); for bool, value is None unless param
                is exactly "True" or "False"
            :rtype: tuple
        '''
        paramValue, paramType = None, None
        if defaultType is not None:
            if defaultType == str:
                paramValue, paramType = param, str
            elif defaultType == int:
                paramValue, paramType = int(param), int
            elif defaultType == float:
                paramValue, paramType = float(param), float
            elif defaultType == bool:
                paramType = bool
                if param == "False":
                    paramValue = False
                elif param == "True":
                    paramValue = True
                else:
                    paramValue = None
        else:
            if param == 'True':
                paramValue, paramType = True, bool
            elif param == 'False':
                paramValue, paramType = False, bool
            elif self.INT_RE.match(param) is not None:
                paramValue, paramType = int(param), int
            else:
                try:
                    paramValue, paramType = float(param), float
                except ValueError:
                    paramValue, paramType = str(param), str
        return (paramValue, paramType)

    def parseParams(self, sysArgv):
        '''Parse actor arguments from the command line

        Compares the actual arguments to the formal arguments (from the model) and
        fills out the local parameter table accordingly. Generates a warning on
        extra arguments and raises an exception on required but missing ones.
        '''
        self.params = {}
        formals = self.model["formals"]
        optList = []
        for formal in formals:
            key = formal["name"]
            default = formal.get("default")
            self.params[key] = default
            optList.append("%s=" % key)
        try:
            opts, _args = getopt.getopt(sysArgv, '', optList)
        except getopt.GetoptError:
            self.logger.info("Error parsing actor options %s" % str(sysArgv))
            return
        for optName2, optValue in opts:
            optName = optName2[2:]  # Drop two leading dashes
            if optName in self.params:
                current = self.params[optName]
                defaultType = None if current is None else type(current)
                paramValue, paramType = self.getParameterValueType(optValue, defaultType)
                if current is not None and paramType != type(current):
                    raise BuildError(
                        "Type of default value does not match type of argument %s" %
                        str((optName, optValue)))
                self.params[optName] = paramValue
            else:
                self.logger.info("Unknown argument %s - ignored" % optName)
        for param in self.params:
            if self.params[param] is None:
                raise BuildError("Required parameter %s missing" % param)

    def buildInstArgs(self, instName, formals, actuals):
        '''Resolve the actual argument values for a component instance.

        Each formal is bound from its actual (an actor parameter reference or a
        literal value), falling back to the formal's default.

        :raises BuildError: if a formal ends up without a value.
        '''
        args = {}
        for formal in formals:
            argName = formal['name']
            argValue = None
            actual = next((a for a in actuals if a['name'] == argName), None)
            defaultValue = formal.get('default')
            if actual is not None:
                if 'param' in actual:
                    paramName = actual['param']
                    if paramName in self.params:
                        argValue = self.params[paramName]
                    else:
                        raise BuildError(
                            "Unspecified parameter %s referenced in %s" % (paramName, instName))
                elif 'value' in actual:
                    argValue = actual['value']
                else:
                    raise BuildError("Actual parameter %s has no value" % argName)
            elif defaultValue is not None:
                argValue = defaultValue
            else:
                raise BuildError("Argument %s in %s has no defined value" % (argName, instName))
            args[argName] = argValue
        return args

    def isLocalMessage(self, msgTypeName):
        '''Return True if the message type is local
        '''
        return msgTypeName in self.localNames

    def isInnerMessage(self, msgTypeName):
        '''Return True if the message type is internal
        '''
        return msgTypeName in self.internalNames

    def getLocalIface(self):
        '''Return the IP address of the host-local network interface (usually 127.0.0.1)
        '''
        return self.localHost

    def getGlobalIface(self):
        '''Return the IP address of the global network interface
        '''
        return self.globalHost

    def getActorName(self):
        '''Return the name of this actor (as defined in the app model)
        '''
        return self.name

    def getAppName(self):
        '''Return the name of the app this actor belongs to
        '''
        return self.appName

    def getActorID(self):
        '''Returns an ID for this actor.

        The actor's id constructed from the host's IP address the actor's process id.
        The id is unique for a given host and actor run.
        '''
        return self.actorID

    def setUUID(self, uuid):
        '''Sets the UUID for this actor.

        The UUID is dynamically generated (by the peer-to-peer network system)
        and is unique.
        '''
        self.uuid = uuid

    def getUUID(self):
        '''Return the UUID for this actor.
        '''
        return self.uuid

    def setupIfaces(self):
        '''Find the IP addresses of the (host-)local and network(-global) interfaces

        :raises AssertionError: if no active global interface is found.
        '''
        (globalIPs, globalMACs, _globalNames, localIP) = getNetworkInterfaces()
        if not (len(globalIPs) > 0 and len(globalMACs) > 0):
            # Explicit raise instead of `assert`, which is stripped under -O.
            self.logger.error("Error: no active network interface")
            raise AssertionError("no active network interface")
        self.localHost = localIP
        self.globalHost = globalIPs[0]
        self.macAddress = globalMACs[0]

    def setup(self):
        '''Perform a setup operation on the actor, after  the initial construction
        but before the activation of parts
        '''
        self.logger.info("setup")
        self.suffix = self.macAddress
        self.disco = DiscoClient(self, self.suffix)
        self.disco.start()  # Start the discovery service client
        self.disco.registerApp()  # Register this actor with the discovery service
        self.logger.info("actor registered with disco")
        self.deplc = DeplClient(self, self.suffix)
        self.deplc.start()
        ok = self.deplc.registerApp()
        self.logger.info("actor %s registered with depl" % ("is" if ok else "is not"))

        self.controls = {}
        self.controlMap = {}
        for inst in self.components:
            comp = self.components[inst]
            control = self.context.socket(zmq.PAIR)
            control.bind('inproc://part_' + inst + '_control')
            self.controls[inst] = control
            self.controlMap[id(control)] = comp
            if isinstance(comp, Part):
                self.components[inst].setup(control)
            else:
                self.components[inst].setup()

    def registerEndpoint(self, bundle):
        '''
        Relay the endpoint registration message to the discovery service client
        '''
        self.logger.info("registerEndpoint")
        result = self.disco.registerEndpoint(bundle)
        for res in result:
            (partName, portName, host, port) = res
            self.updatePart(partName, portName, host, port)

    def registerDevice(self, bundle):
        '''Relay the device registration message to the device interface service client
        '''
        typeName, args = bundle
        msg = (self.appName, self.modelName, typeName, args)
        result = self.deplc.registerDevice(msg)
        return result

    def unregisterDevice(self, bundle):
        '''Relay the device unregistration message to the device interface service client
        '''
        typeName, = bundle
        msg = (self.appName, self.modelName, typeName)
        result = self.deplc.unregisterDevice(msg)
        return result

    def activate(self):
        '''Activate the parts
        '''
        self.logger.info("activate")
        for inst in self.components:
            self.components[inst].activate()

    def deactivate(self):
        '''Deactivate the parts
        '''
        self.logger.info("deactivate")
        for inst in self.components:
            self.components[inst].deactivate()

    def recvChannelMessages(self, channel):
        '''Collect all messages from the channel queue and return them in a list
        '''
        msgs = []
        while True:
            try:
                msg = channel.recv(flags=zmq.NOBLOCK)
                msgs.append(msg)
            except zmq.Again:
                break
        return msgs

    def start(self):
        '''
        Start and operate the actor (infinite polling loop)
        '''
        self.logger.info("starting")
        self.discoChannel = self.disco.channel  # Private channel to the discovery service
        self.deplChannel = self.deplc.channel

        self.poller = zmq.Poller()  # Set up the poller
        self.poller.register(self.deplChannel, zmq.POLLIN)
        self.poller.register(self.discoChannel, zmq.POLLIN)
        for control in self.controls:
            self.poller.register(self.controls[control], zmq.POLLIN)

        while True:
            sockets = dict(self.poller.poll())
            if self.discoChannel in sockets:  # If there is a message from a service, handle it
                msgs = self.recvChannelMessages(self.discoChannel)
                for msg in msgs:
                    self.handleServiceUpdate(msg)  # Handle message from disco service
                del sockets[self.discoChannel]
            elif self.deplChannel in sockets:
                msgs = self.recvChannelMessages(self.deplChannel)
                for msg in msgs:
                    self.handleDeplMessage(msg)  # Handle message from depl service
                del sockets[self.deplChannel]
            else:  # Handle messages from the components.
                toDelete = []
                for s in sockets:
                    if s in self.controls.values():
                        part = self.controlMap[id(s)]
                        msg = s.recv_pyobj()  # receive python object from component
                        self.handleEventReport(part, msg)  # Report event
                        toDelete += [s]
                for s in toDelete:
                    del sockets[s]

    def handleServiceUpdate(self, msgBytes):
        '''
        Handle a service update message from the discovery service
        '''
        msgUpd = disco_capnp.DiscoUpd.from_bytes(msgBytes)  # Parse the incoming message

        which = msgUpd.which()
        if which == 'portUpdate':
            msg = msgUpd.portUpdate
            client = msg.client
            actorHost = client.actorHost
            assert actorHost == self.globalHost  # It has to be addressed to this actor
            actorName = client.actorName
            assert actorName == self.name
            instanceName = client.instanceName
            assert instanceName in self.components  # It has to be for a part of this actor
            portName = client.portName
            scope = msg.scope
            socket = msg.socket
            host = socket.host
            port = socket.port
            if scope == "local":
                assert host == self.localHost
            self.updatePart(instanceName, portName, host, port)  # Update the selected part

    def updatePart(self, instanceName, portName, host, port):
        '''
        Ask a part to update itself
        '''
        self.logger.info("updatePart %s" % str((instanceName, portName, host, port)))
        part = self.components[instanceName]
        part.handlePortUpdate(portName, host, port)

    def handleDeplMessage(self, msgBytes):
        '''
        Handle a message from the deployment service
        '''
        msgUpd = deplo_capnp.DeplCmd.from_bytes(msgBytes)  # Parse the incoming message

        which = msgUpd.which()
        if which == 'resourceMsg':
            what = msgUpd.resourceMsg.which()
            if what == 'resCPUX':
                self.handleCPULimit()
            elif what == 'resMemX':
                self.handleMemLimit()
            elif what == 'resSpcX':
                self.handleSpcLimit()
            elif what == 'resNetX':
                self.handleNetLimit()
            else:
                self.logger.error("unknown resource msg from deplo: '%s'" % what)
        elif which == 'reinstateCmd':
            self.handleReinstate()
        elif which == 'nicStateMsg':
            stateMsg = msgUpd.nicStateMsg
            state = str(stateMsg.nicState)
            self.handleNICStateChange(state)
        elif which == 'peerInfoMsg':
            peerMsg = msgUpd.peerInfoMsg
            state = str(peerMsg.peerState)
            uuid = peerMsg.uuid
            self.handlePeerStateChange(state, uuid)
        else:
            self.logger.error("unknown msg from deplo: '%s'" % which)

    def handleReinstate(self):
        '''Reconnect to the discovery service and let each component reinstate itself.
        '''
        self.logger.info('handleReinstate')
        self.poller.unregister(self.discoChannel)
        self.disco.reconnect()
        self.discoChannel = self.disco.channel
        self.poller.register(self.discoChannel, zmq.POLLIN)
        for inst in self.components:
            self.components[inst].handleReinstate()

    def handleNICStateChange(self, state):
        '''
        Handle the NIC state change message: notify components
        '''
        self.logger.info("handleNICStateChange")
        for component in self.components.values():
            component.handleNICStateChange(state)

    def handlePeerStateChange(self, state, uuid):
        '''
        Handle the peer state change message: notify components
        '''
        self.logger.info("handlePeerStateChange")
        for component in self.components.values():
            component.handlePeerStateChange(state, uuid)

    def handleCPULimit(self):
        '''
        Handle the case when the CPU limit is exceeded: notify each component.
        If the component has defined a handler, it will be called.
        '''
        self.logger.info("handleCPULimit")
        for component in self.components.values():
            component.handleCPULimit()

    def handleMemLimit(self):
        '''
        Handle the case when the memory limit is exceeded: notify each component.
        If the component has defined a handler, it will be called.
        '''
        self.logger.info("handleMemLimit")
        for component in self.components.values():
            component.handleMemLimit()

    def handleSpcLimit(self):
        '''
        Handle the case when the file space limit is exceeded: notify each component.
        If the component has defined a handler, it will be called.
        '''
        self.logger.info("handleSpcLimit")
        for component in self.components.values():
            component.handleSpcLimit()

    def handleNetLimit(self):
        '''
        Handle the case when the net usage limit is exceeded: notify each component.
        If the component has defined a handler, it will be called.
        '''
        self.logger.info("handleNetLimit")
        for component in self.components.values():
            component.handleNetLimit()

    def handleEventReport(self, part, msg):
        '''Handle event report from a part

        The event report is forwarded to the deplo service.
        '''
        partName = part.getName()
        typeName = part.getTypeName()
        bundle = (partName, typeName, msg)
        self.deplc.reportEvent(bundle)

    def terminate(self):
        '''Terminate all functions of the actor.

        Terminate all components, and connections to the deplo/disco services.
        Finally exit the process.
        '''
        self.logger.info("terminating")
        for component in self.components.values():
            component.terminate()
        time.sleep(1.0)
        self.deplc.terminate()
        self.disco.terminate()
        if self.auth:
            self.auth.stop()
        self.logger.info("terminated")
        os._exit(0)
class Command(LAVADaemonCommand):
    """lava-logs daemon: receive job log lines over ZMQ and record them.

    Log lines arrive on a PULL socket as two-part ``[job_id, yaml_message]``
    messages. Each line is appended through a per-job ``JobHandler``.
    Lines whose level is ``"results"`` are also mapped to ``TestCase`` rows,
    which are cached and bulk-created periodically. A ROUTER socket keeps a
    PING/PONG heartbeat with the master.
    """

    help = "LAVA log recorder"
    logger = None
    default_logfile = "/var/log/lava-server/lava-logs.log"

    def __init__(self, *args, **options):
        super(Command, self).__init__(*args, **options)
        self.logger = logging.getLogger("lava-logs")
        self.log_socket = None
        self.auth = None
        self.controler = None
        # Raw inotify file descriptor (int) watching the slaves certificates.
        self.inotify_fd = None
        # Read end of the pipe used to turn signals into poll events.
        self.pipe_r = None
        self.poller = None
        self.cert_dir_path = None
        # List of logs
        self.jobs = {}
        # Keep test cases in memory
        self.test_cases = []
        # Master status
        self.last_ping = 0
        self.ping_interval = TIMEOUT

    def add_arguments(self, parser):
        """Register the network-related command line options."""
        super(Command, self).add_arguments(parser)

        net = parser.add_argument_group("network")
        net.add_argument('--socket',
                         default='tcp://*:5555',
                         help="Socket waiting for logs. Default: tcp://*:5555")
        net.add_argument('--master-socket',
                         default='tcp://localhost:5556',
                         help="Socket for master-slave communication. Default: tcp://localhost:5556")
        net.add_argument('--ipv6', default=False, action='store_true',
                         help="Enable IPv6 on the listening sockets")
        net.add_argument('--encrypt', default=False, action='store_true',
                         help="Encrypt messages")
        net.add_argument('--master-cert',
                         default='/etc/lava-dispatcher/certificates.d/master.key_secret',
                         help="Certificate for the master socket")
        net.add_argument('--slaves-certs',
                         default='/etc/lava-dispatcher/certificates.d',
                         help="Directory for slaves certificates")

    def handle(self, *args, **options):
        """Set up the sockets, run the main loop, then drain and shut down."""
        # Initialize logging.
        self.setup_logging("lava-logs", options["level"], options["log_file"], FORMAT)

        self.logger.info("[INIT] Dropping privileges")
        if not self.drop_privileges(options['user'], options['group']):
            self.logger.error("[INIT] Unable to drop privileges")
            return

        # Create the sockets
        context = zmq.Context()
        self.log_socket = context.socket(zmq.PULL)
        self.controler = context.socket(zmq.ROUTER)
        self.controler.setsockopt(zmq.IDENTITY, b"lava-logs")
        # Limit the number of messages in the queue
        self.controler.setsockopt(zmq.SNDHWM, 2)
        # From http://api.zeromq.org/4-2:zmq-setsockopt#toc5
        # "Immediately readies that connection for data transfer with the master"
        self.controler.setsockopt(zmq.CONNECT_RID, b"master")

        if options['ipv6']:
            self.logger.info("[INIT] Enabling IPv6")
            self.log_socket.setsockopt(zmq.IPV6, 1)
            self.controler.setsockopt(zmq.IPV6, 1)

        if options['encrypt']:
            self.logger.info("[INIT] Starting encryption")
            try:
                self.auth = ThreadAuthenticator(context)
                self.auth.start()
                self.logger.debug("[INIT] Opening master certificate: %s", options['master_cert'])
                master_public, master_secret = zmq.auth.load_certificate(options['master_cert'])
                self.logger.debug("[INIT] Using slaves certificates from: %s", options['slaves_certs'])
                self.auth.configure_curve(domain='*', location=options['slaves_certs'])
            except IOError as err:
                self.logger.error("[INIT] %s", err)
                if self.auth is not None:
                    self.auth.stop()
                return
            # The log socket is a CURVE server; the controler is a CURVE
            # client of the master and reuses the same key pair.
            self.log_socket.curve_publickey = master_public
            self.log_socket.curve_secretkey = master_secret
            self.log_socket.curve_server = True
            self.controler.curve_publickey = master_public
            self.controler.curve_secretkey = master_secret
            self.controler.curve_serverkey = master_public

            self.logger.debug("[INIT] Watching %s", options["slaves_certs"])
            self.cert_dir_path = options["slaves_certs"]
            self.inotify_fd = watch_directory(options["slaves_certs"])
            if self.inotify_fd is None:
                self.logger.error("[INIT] Unable to start inotify")

        self.log_socket.bind(options['socket'])
        self.controler.connect(options['master_socket'])

        # Poll on the sockets. This allow to have a
        # nice timeout along with polling.
        self.poller = zmq.Poller()
        self.poller.register(self.log_socket, zmq.POLLIN)
        self.poller.register(self.controler, zmq.POLLIN)
        if self.inotify_fd is not None:
            # Register the raw fd (like pipe_r below). Poller.poll() returns the
            # object that was registered, so wrapping the fd with os.fdopen()
            # meant wait_for_messages() could never find it by fd.
            self.poller.register(self.inotify_fd, zmq.POLLIN)

        # Translate signals into zmq messages
        (self.pipe_r, _) = self.setup_zmq_signal_handler()
        self.poller.register(self.pipe_r, zmq.POLLIN)

        self.logger.info("[INIT] listening for logs")
        # PING right now: the master is waiting for this message to start
        # scheduling.
        self.controler.send_multipart([b"master", b"PING"])

        try:
            self.main_loop()
        except BaseException as exc:
            self.logger.error("[EXIT] Unknown exception raised, leaving!")
            self.logger.exception(exc)

        # Close the controler socket
        self.controler.close(linger=0)
        self.poller.unregister(self.controler)

        # Carefully close the logging socket as we don't want to lose messages
        self.logger.info("[EXIT] Disconnect logging socket and process messages")
        endpoint = u(self.log_socket.getsockopt(zmq.LAST_ENDPOINT))
        self.logger.debug("[EXIT] unbinding from '%s'", endpoint)
        self.log_socket.unbind(endpoint)

        # Empty the queue
        try:
            while self.wait_for_messages(True):
                # Flush test cases cache for every iteration because we might
                # get killed soon.
                self.flush_test_cases()
        except BaseException as exc:
            self.logger.error("[EXIT] Unknown exception raised, leaving!")
            self.logger.exception(exc)
        finally:
            # Last flush
            self.flush_test_cases()
            self.logger.info("[EXIT] Closing the logging socket: the queue is empty")
            self.log_socket.close()
            if options['encrypt']:
                self.auth.stop()
            context.term()

    def flush_test_cases(self):
        """Bulk-insert the cached TestCase objects and clear the cache."""
        if self.test_cases:
            self.logger.info("Saving %d test cases", len(self.test_cases))
            TestCase.objects.bulk_create(self.test_cases)
            self.test_cases = []

    def main_loop(self):
        """Process messages until wait_for_messages() asks to stop.

        Between messages it periodically flushes cached test cases, closes
        job handlers unused for FD_TIMEOUT seconds, and PINGs the master.
        """
        last_gc = time.time()
        last_bulk_create = time.time()

        # Wait for messages
        # TODO: fix timeout computation
        while self.wait_for_messages(False):
            now = time.time()

            # Dump TestCase into the database
            if now - last_bulk_create > BULK_CREATE_TIMEOUT:
                last_bulk_create = now
                self.flush_test_cases()

            # Close old file handlers
            if now - last_gc > FD_TIMEOUT:
                last_gc = now
                # Snapshot the keys: entries are deleted while iterating.
                for job_id in list(self.jobs):
                    if now - self.jobs[job_id].last_usage > FD_TIMEOUT:
                        self.logger.info("[%s] closing log file", job_id)
                        self.jobs[job_id].close()
                        del self.jobs[job_id]

            # Ping the master
            if now - self.last_ping > self.ping_interval:
                self.logger.debug("PING => master")
                self.last_ping = now
                self.controler.send_multipart([b"master", b"PING"])

    def wait_for_messages(self, leaving):
        """Poll once and dispatch whatever became ready.

        :param leaving: True while draining the queue during shutdown.
        :return: True to keep looping, False to stop the caller's loop.
        """
        try:
            try:
                sockets = dict(self.poller.poll(TIMEOUT * 1000))
            except zmq.error.ZMQError as exc:
                self.logger.error("[POLL] zmq error: %s", str(exc))
                return True

            # Messages
            if sockets.get(self.log_socket) == zmq.POLLIN:
                self.logging_socket()
                return True

            # Signals
            elif sockets.get(self.pipe_r) == zmq.POLLIN:
                # remove the message from the queue
                os.read(self.pipe_r, 1)

                if not leaving:
                    self.logger.info("[POLL] received a signal, leaving")
                    return False
                else:
                    self.logger.warning("[POLL] signal already handled, please wait for the process to exit")
                    return True

            # Pong received
            elif sockets.get(self.controler) == zmq.POLLIN:
                self.controler_socket()
                return True

            # Inotify socket
            elif self.inotify_fd is not None and sockets.get(self.inotify_fd) == zmq.POLLIN:
                # Consume the pending inotify events so the fd stops polling
                # as readable, then reload the allowed client keys.
                os.read(self.inotify_fd, 4096)
                self.logger.debug("[AUTH] Reloading certificates from %s", self.cert_dir_path)
                self.auth.configure_curve(domain='*', location=self.cert_dir_path)
                # Explicitly continue: falling through used to return None,
                # which the callers' while-loops treat as "stop".
                return True

            # Nothing received
            else:
                return not leaving
        except (OperationalError, InterfaceError):
            self.logger.info("[RESET] database connection reset")
            connection.close()
            return True

    def logging_socket(self):
        """Receive one log line, record results and append it to the job log."""
        msg = self.log_socket.recv_multipart()
        try:
            (job_id, message) = (u(m) for m in msg)  # pylint: disable=unbalanced-tuple-unpacking
        except ValueError:
            # do not let a bad message stop the master.
            self.logger.error("[POLL] failed to parse log message, skipping: %s", msg)
            return

        try:
            # NOTE(review): `message` comes from the network. yaml.CLoader is
            # not the safe loader; consider yaml.CSafeLoader once confirmed
            # that dispatchers never emit Python-specific YAML tags.
            scanned = yaml.load(message, Loader=yaml.CLoader)
        except yaml.YAMLError:
            self.logger.error("[%s] data are not valid YAML, dropping", job_id)
            return

        # Look for "results" level
        try:
            message_lvl = scanned["lvl"]
            message_msg = scanned["msg"]
        except TypeError:
            self.logger.error("[%s] not a dictionary, dropping", job_id)
            return
        except KeyError:
            self.logger.error(
                "[%s] invalid log line, missing \"lvl\" or \"msg\" keys: %s",
                job_id, message)
            return

        # Find the handler (if available)
        if job_id not in self.jobs:
            # Query the database for the job
            try:
                job = TestJob.objects.get(id=job_id)
            except TestJob.DoesNotExist:
                self.logger.error("[%s] unknown job id", job_id)
                return

            self.logger.info("[%s] receiving logs from a new job", job_id)
            # Create the sub directories (if needed)
            mkdir(job.output_dir)
            self.jobs[job_id] = JobHandler(job)

        if message_lvl == "results":
            try:
                job = TestJob.objects.get(pk=job_id)
            except TestJob.DoesNotExist:
                self.logger.error("[%s] unknown job id", job_id)
                return
            meta_filename = create_metadata_store(message_msg, job)
            new_test_case = map_scanned_results(results=message_msg, job=job,
                                                meta_filename=meta_filename)
            if new_test_case is None:
                self.logger.warning(
                    "[%s] unable to map scanned results: %s",
                    job_id, message)
            else:
                self.test_cases.append(new_test_case)

            # Look for lava.job result
            if message_msg.get("definition") == "lava" and message_msg.get("case") == "job":
                # Flush cached test cases
                self.flush_test_cases()

                if message_msg.get("result") == "pass":
                    health = TestJob.HEALTH_COMPLETE
                    health_msg = "Complete"
                else:
                    health = TestJob.HEALTH_INCOMPLETE
                    health_msg = "Incomplete"
                self.logger.info("[%s] job status: %s", job_id, health_msg)

                infrastructure_error = (message_msg.get("error_type") in ["Bug", "Configuration", "Infrastructure"])
                if infrastructure_error:
                    self.logger.info("[%s] Infrastructure error", job_id)

                # Update status.
                with transaction.atomic():
                    # TODO: find a way to lock actual_device
                    job = TestJob.objects.select_for_update() \
                                         .get(id=job_id)
                    job.go_state_finished(health, infrastructure_error)
                    job.save()

        # Mark the file handler as used
        self.jobs[job_id].last_usage = time.time()
        # n.b. logging here would produce a log entry for every message in every job.
        # The format is a list of dictionaries
        message = "- %s" % message

        # Write data
        self.jobs[job_id].write(message)

    def controler_socket(self):
        """Handle a PONG from the master and adopt its ping interval."""
        msg = self.controler.recv_multipart()
        try:
            master_id = u(msg[0])
            action = u(msg[1])
            ping_interval = int(msg[2])

            if master_id != "master":
                self.logger.error("Invalid master id '%s'. Should be 'master'",
                                  master_id)
                return
            if action != "PONG":
                self.logger.error("Invalid answer '%s'. Should be 'PONG'",
                                  action)
                return
        except (IndexError, ValueError):
            self.logger.error("Invalid message '%s'", msg)
            return

        if ping_interval < TIMEOUT:
            self.logger.error("invalid ping interval (%d) too small", ping_interval)
            return

        self.logger.debug("master => PONG(%d)", ping_interval)
        self.ping_interval = ping_interval
class TaskQueue:
    """Outgoing task queue from the executor to the Interchange.

    Wraps one ZeroMQ socket secured with CURVE:

    * ``mode="server"``: a ROUTER socket bound on ``tcp://*:<port>``.
    * ``mode="client"``: a DEALER socket connected to ``tcp://<address>:<port>``.
    """

    def __init__(
        self,
        address: str,
        port: int = 55001,
        identity: "str | None" = None,
        zmq_context=None,
        set_hwm=False,
        RCVTIMEO=None,
        SNDTIMEO=None,
        linger=None,
        ironhouse: bool = False,
        keys_dir: str = os.path.abspath(".curve"),
        mode: str = "client",
    ):
        """
        Parameters
        ----------

        address: str
             address to connect

        port: int
             Port to use

        identity : str
             Applies only to clients, where the identity must match the endpoint uuid.
             This will be utf-8 encoded on the wire. When omitted, a fresh random
             uuid4 string is generated for this instance.

        mode: string
             Either 'client' or 'server'

        keys_dir : string
             Directory from which keys will be loaded for curve.

        ironhouse: Bool
             Only valid for server mode. Setting this flag switches the server to require
             client keys to be available on the server in the keys_dir.
        """
        if identity is None:
            # Generate per instance: a str(uuid.uuid4()) default in the signature
            # is evaluated once at import time and shared by every client.
            identity = str(uuid.uuid4())

        if zmq_context:
            self.context = zmq_context
        else:
            self.context = zmq.Context()

        self.mode = mode
        self.port = port
        self.ironhouse = ironhouse
        self.keys_dir = keys_dir
        # Set by setup_server_auth() in server mode only; close() checks it.
        self.auth = None

        assert self.mode in [
            "client",
            "server",
        ], "Only two modes are supported: client, server"

        if self.mode == "server":
            log.debug("Configuring server")
            self.zmq_socket = self.context.socket(zmq.ROUTER)
            self.zmq_socket.set(zmq.ROUTER_MANDATORY, 1)
            self.zmq_socket.set(zmq.ROUTER_HANDOVER, 1)
            log.debug("Setting up auth-server")
            self.setup_server_auth()
        elif self.mode == "client":
            self.zmq_socket = self.context.socket(zmq.DEALER)
            self.setup_client_auth()
            self.zmq_socket.setsockopt(zmq.IDENTITY, identity.encode("utf-8"))
        else:
            raise ValueError(
                "TaskQueue must be initialized with mode set to 'server' or 'client'"
            )

        if set_hwm:
            self.zmq_socket.set_hwm(0)
        if RCVTIMEO is not None:
            self.zmq_socket.setsockopt(zmq.RCVTIMEO, RCVTIMEO)
        if SNDTIMEO is not None:
            self.zmq_socket.setsockopt(zmq.SNDTIMEO, SNDTIMEO)
        if linger is not None:
            self.zmq_socket.setsockopt(zmq.LINGER, linger)

        # all zmq setsockopt calls must be done before bind/connect is called
        if self.mode == "server":
            self.zmq_socket.bind(f"tcp://*:{port}")
        elif self.mode == "client":
            self.zmq_socket.connect(f"tcp://{address}:{port}")

        # Register for POLLIN only. The default (POLLIN | POLLOUT) makes poll()
        # return as soon as the socket is writable, so get() would ignore its
        # timeout and see revents == POLLIN | POLLOUT instead of POLLIN.
        self.poller = zmq.Poller()
        self.poller.register(self.zmq_socket, zmq.POLLIN)

        os.makedirs(self.keys_dir, exist_ok=True)
        log.debug(f"Initializing Taskqueue:{self.mode} on port:{self.port}")

    def zmq_context(self):
        return self.context

    def add_client_key(self, endpoint_id, client_key):
        """Store a client's public key and reload the allowed keys (ironhouse only).

        In stonehouse mode (``ironhouse=False``) the authenticator accepts any
        client key, so nothing is written or reconfigured.
        """
        log.info("Adding client key")
        if self.ironhouse:
            # Use the ironhouse ZMQ pattern: http://hintjens.com/blog:49#toc6
            with open(os.path.join(self.keys_dir, f"{endpoint_id}.key"), "w") as f:
                f.write(client_key)
            try:
                self.auth.configure_curve(domain="*", location=self.keys_dir)
            except Exception:
                log.exception("Failed to load keys from %s", self.keys_dir)

    def setup_server_auth(self):
        """Start the authenticator and load the server key pair (server mode)."""
        # Start an authenticator for this context.
        self.auth = ThreadAuthenticator(self.context)
        self.auth.start()
        self.auth.allow("127.0.0.1")
        # Tell the authenticator how to handle CURVE requests

        if not self.ironhouse:
            # Use the stonehouse ZMQ pattern: http://hintjens.com/blog:49#toc5
            self.auth.configure_curve(domain="*", location=zmq.auth.CURVE_ALLOW_ANY)

        server_secret_file = os.path.join(self.keys_dir, "server.key_secret")
        server_public, server_secret = zmq.auth.load_certificate(server_secret_file)
        self.zmq_socket.curve_secretkey = server_secret
        self.zmq_socket.curve_publickey = server_public
        self.zmq_socket.curve_server = True  # must come before bind

    def setup_client_auth(self):
        """Load the client key pair and the server public key (client mode)."""
        # We need two certificates, one for the client and one for
        # the server. The client must know the server's public key
        # to make a CURVE connection.
        client_secret_file = os.path.join(self.keys_dir, "endpoint.key_secret")
        client_public, client_secret = zmq.auth.load_certificate(client_secret_file)
        self.zmq_socket.curve_secretkey = client_secret
        self.zmq_socket.curve_publickey = client_public

        # The client must know the server's public key to make a CURVE connection.
        server_public_file = os.path.join(self.keys_dir, "server.key")
        server_public, _ = zmq.auth.load_certificate(server_public_file)
        self.zmq_socket.curve_serverkey = server_public

    def get(self, block=True, timeout=1000):
        """
        Parameters
        ----------

        block : Bool
            Blocks until there's a message, Default is True
        timeout : int
            Milliseconds to wait (only used when ``block`` is not True).

        Raises
        ------
        zmq.Again if no message arrived within ``timeout``.
        """
        # timeout is in milliseconds
        if block is True:
            return self.zmq_socket.recv_multipart()

        socks = dict(self.poller.poll(timeout=timeout))
        # Test the POLLIN bit rather than comparing for equality: revents may
        # carry other bits besides POLLIN.
        if socks.get(self.zmq_socket, 0) & zmq.POLLIN:
            return self.zmq_socket.recv_multipart()
        raise zmq.Again

    def register_client(self, message):
        return self.zmq_socket.send_multipart([message])

    def put(self, dest, message, max_timeout=1000):
        """Send ``message`` (to ``dest`` when in server mode).

        We could set copy=False and get slightly better latency but this results
        in ZMQ sockets reaching a broken state once there are ~10k tasks in flight.
        This issue can be magnified if each the serialized buffer itself is larger.

        Parameters
        ----------

        dest : zmq_identity of the destination endpoint, must be a byte string
            Ignored in client mode.

        message : py object
             Python object to send

        max_timeout : int
             Currently unused; kept for backward compatibility. Send timeouts
             are governed by the SNDTIMEO socket option given at construction.

        Raises
        ------

        zmq.EAGAIN if the send failed.
        zmq.error.ZMQError: Host unreachable (if client disconnects?)

        """
        if self.mode == "client":
            return self.zmq_socket.send_multipart([message])
        else:
            return self.zmq_socket.send_multipart([dest, message])

    def close(self):
        """Close the socket, stop the authenticator (if any) and terminate the context."""
        self.zmq_socket.close()
        auth = getattr(self, "auth", None)
        if auth is not None:
            # The authenticator thread holds its own sockets in this context,
            # and Context.term() waits for all sockets to be closed, so stop it first.
            auth.stop()
            self.auth = None
        self.context.term()
class SecurityManager(object):
    """
    Manages the security related settings and actions.

    Attributes:
        private_cert_dir (str):
            The path towards the directory that stores private certificates.
        public_cert_dir (str):
            The path towards the directory that stores public certificates.
        temp_cert_dir (str):
            The path towards the temporary directory (used for generating
            new certificates). If not provided, it defaults to
            system's temporary directory.
        no_encryption (bool):
            Enable or disable encryption at peer level. Peers that use
            encryption can only connect to other peers that use
            encryption and vice-versa.
        public_file (str):
            The path towards the public certificate of this peer.
        private_file (str):
            The path towards the private certificate of this peer.
        auth_thread (ThreadAuthenticator):
            A separate thread used by zmq to authenticate our peers.
    """

    # Lines of a certificate that carry the secret key. pyzmq writes it as
    # `secret-key = "..."`; `private-key` is also accepted defensively.
    _SECRET_KEY_LINE = re.compile(
        r'^[ \t]*(?:secret|private)-key[ \t]*=.*(?:\r?\n)?', re.MULTILINE)

    def __init__(self, private_cert_dir, public_cert_dir,
                 temp_cert_dir=None, no_encryption=False,
                 *args, **kwargs):
        """
        Constructor.

        Arguments:
            private_cert_dir (str):
                The path towards the directory that stores private certificates.
            public_cert_dir (str):
                The path towards the directory that stores public certificates.
            temp_cert_dir (str):
                The path towards the temporary directory (used for generating
                new certificates). Defaults to system's temporary directory.
            no_encryption (bool):
                Enable or disable encryption at peer level. Peers that use
                encryption can only connect to other peers that use
                encryption and vice-versa.
        """
        super(SecurityManager, self).__init__(*args, **kwargs)
        self.private_cert_dir = private_cert_dir
        self.public_cert_dir = public_cert_dir
        if temp_cert_dir is None:
            self.temp_cert_dir = tempfile.gettempdir()
        else:
            self.temp_cert_dir = temp_cert_dir

        # Paths for certificates used by this instance.
        self.public_file = None
        self.private_file = None

        # Security settings.
        self.no_encryption = no_encryption

        # Authentication thread.
        self.auth_thread = None

    def start_auth(self, context):
        """
        Starts the authentication thread if encryption is enabled.

        Arguments:
            context (zmq.Context):
                The context where the authentication thread will belong.
        """
        if not self.no_encryption:
            logger.debug("Authenticator thread is being started")
            self.auth_thread = ThreadAuthenticator(
                context=context, encoding='utf-8',
                log=logging.getLogger('zmq_auth'))
            self.auth_thread.start()
            self.auth_thread.thread.name = 'zmq_auth'
            self.auth_thread.configure_curve(domain='*',
                                             location=self.public_cert_dir)

    def terminate_auth(self):
        """
        Ends the authentication thread if encryption is enabled.

        This method should be written defensively, as the environment
        might not be fully set (an exception in
        :meth:`p2p0mq.app.theapp.LocalPeer.create` does not prevent this
        method from being executed).
        """
        if self.auth_thread is not None:
            logger.debug("Authenticator thread is being stopped")
            self.auth_thread.stop()
            self.auth_thread = None

    def prepare_cert_store(self, uuid):
        """
        Prepares the directory structure before it can be used by our
        authentication system.

        Arguments:
            uuid: The unique identification of local peer.
        """
        for directory in (self.private_cert_dir,
                          self.public_cert_dir,
                          self.temp_cert_dir):
            if not os.path.isdir(directory):
                os.makedirs(directory)

        self.public_file, self.private_file = \
            self.cert_pair_check_gen(uuid)

    def cert_pair_check_gen(self, uuid):
        """
        Checks if the certificates exist. Generates them if they don't.

        Arguments:
            uuid:
                The unique identification of the peer. Usually, this is
                the local peer.

        Returns:
            (public_path, private_path) of the certificate pair.

        Raises:
            RuntimeError: the public certificate exists without its private one.
        """
        cert_pub = self.cert_file_by_uuid(uuid, public=True)
        cert_prv = self.cert_file_by_uuid(uuid, public=False)
        pub_exists = os.path.isfile(cert_pub)
        prv_exists = os.path.isfile(cert_prv)

        if pub_exists and prv_exists:
            # Both files exist. Yey.
            pass
        elif pub_exists and not prv_exists:
            # The public certificate exists but is unusable without
            # the private one.
            raise RuntimeError(
                "The public certificate has been found at %s, "
                "which indicates that a key has been generated, "
                "but the private certificate is not at %s" % (
                    cert_pub, cert_prv))
        elif not pub_exists and prv_exists:
            # The private certificate exists but the public one doesn't.
            # Derive the public one by dropping the secret-key line, so the
            # secret never lands in the (shareable) public directory.
            with open(cert_prv, 'r') as fin:
                data = self._SECRET_KEY_LINE.sub('', fin.read())
            with open(cert_pub, 'w') as fout:
                fout.write(data)
        else:
            # Neither exists.
            public_file, secret_file = \
                zmq.auth.create_certificates(
                    self.temp_cert_dir, '%r' % time())
            shutil.move(public_file, cert_pub)
            shutil.move(secret_file, cert_prv)
        return cert_pub, cert_prv

    def cert_file_by_uuid(self, uuid, public=True):
        """
        Computes the path of a certificate inside the certificate store
        based on the name of the peer.

        Arguments:
            uuid:
                The unique identification of the peer. Usually, this is
                the local peer. ``bytes`` values are decoded as UTF-8.
            public (bool):
                If True it retrieves the path of the public certificate,
                if False it retrieves the path of the private certificate.
        """
        if isinstance(uuid, bytes):
            uuid = uuid.decode('utf-8')
        pb_vs_pv = 'key' if public else 'key_secret'
        return os.path.join(
            self.public_cert_dir if public else self.private_cert_dir,
            '%s.%s' % (uuid, pb_vs_pv))

    def cert_key_by_uuid(self, uuid, public=True):
        """
        Reads the key from corresponding certificate file.

        Arguments:
            uuid:
                The unique identification of the peer. Usually, this is
                the local peer.
            public (bool):
                If True it retrieves the public key,
                if False it retrieves the private key.

        Returns:
            The key, or None if the certificate file does not exist.
        """
        file = self.cert_file_by_uuid(uuid=uuid, public=public)
        logger.debug("%s certificate for uuid %s is loaded from %s",
                     'Public' if public else 'Private', uuid, file)
        if not os.path.exists(file):
            return None
        public_key, secret_key = zmq.auth.load_certificate(file)
        return public_key if public else secret_key

    def exchange_certificates(self, other):
        """
        Copies the certificates so that the two instances can
        authenticate themselves to each other.

        Arguments:
            other (SecurityManager):
                The security manager of the other peer.
        """
        if self.no_encryption:
            logger.error("Encryption was disabled in %s", self)
            return
        if other.no_encryption:
            logger.error("Encryption was disabled in %s", other)
            return

        shutil.copy(
            self.public_file,
            os.path.join(other.public_cert_dir,
                         os.path.basename(self.public_file)))
        shutil.copy(
            other.public_file,
            os.path.join(self.public_cert_dir,
                         os.path.basename(other.public_file)))
        # Reload running authenticators. One that has not been started yet
        # reads public_cert_dir (now holding the copied certificate) in
        # start_auth().
        for manager in (self, other):
            if manager.auth_thread is not None:
                manager.auth_thread.configure_curve(
                    domain='*', location=manager.public_cert_dir)
class Command(LAVADaemonCommand): """ worker_host is the hostname of the worker this field is set by the admin and could therefore be empty in a misconfigured instance. """ logger = None help = "LAVA dispatcher master" default_logfile = "/var/log/lava-server/lava-master.log" def __init__(self, *args, **options): super().__init__(*args, **options) self.auth = None self.controler = None self.event_socket = None self.poller = None self.pipe_r = None self.inotify_fd = None # List of logs # List of known dispatchers. At startup do not load this from the # database. This will help to know if the slave as restarted or not. self.dispatchers = { "lava-logs": SlaveDispatcher("lava-logs", online=False) } self.events = {"canceling": set(), "available_dt": set()} def add_arguments(self, parser): super().add_arguments(parser) net = parser.add_argument_group("network") net.add_argument( '--master-socket', default='tcp://*:5556', help="Socket for master-slave communication. Default: tcp://*:5556" ) net.add_argument('--event-url', default="tcp://localhost:5500", help="URL of the publisher") net.add_argument('--ipv6', default=False, action='store_true', help="Enable IPv6 on the listening sockets") net.add_argument('--encrypt', default=False, action='store_true', help="Encrypt messages") net.add_argument( '--master-cert', default='/etc/lava-dispatcher/certificates.d/master.key_secret', help="Certificate for the master socket") net.add_argument('--slaves-certs', default='/etc/lava-dispatcher/certificates.d', help="Directory for slaves certificates") def send_status(self, hostname): """ The master crashed, send a STATUS message to get the current state of jobs """ jobs = TestJob.objects.filter( actual_device__worker_host__hostname=hostname, state=TestJob.STATE_RUNNING) for job in jobs: self.logger.info("[%d] STATUS => %s (%s)", job.id, hostname, job.actual_device.hostname) send_multipart_u(self.controler, [hostname, 'STATUS', str(job.id)]) def dispatcher_alive(self, hostname): if 
hostname not in self.dispatchers: # The server crashed: send a STATUS message self.logger.warning("Unknown dispatcher <%s> (server crashed)", hostname) self.dispatchers[hostname] = SlaveDispatcher(hostname) self.send_status(hostname) # Mark the dispatcher as alive self.dispatchers[hostname].alive() def controler_socket(self): try: # We need here to use the zmq.NOBLOCK flag, otherwise we could block # the whole main loop where this function is called. msg = self.controler.recv_multipart(zmq.NOBLOCK) except zmq.error.Again: return False # This is way to verbose for production and should only be activated # by (and for) developers # self.logger.debug("[CC] Receiving: %s", msg) # 1: the hostname (see ZMQ documentation) hostname = u(msg[0]) # 2: the action action = u(msg[1]) # Check that lava-logs only send PINGs if hostname == "lava-logs" and action != "PING": self.logger.error("%s => %s Invalid action from log daemon", hostname, action) return True # Handle the actions if action == 'HELLO' or action == 'HELLO_RETRY': self._handle_hello(hostname, action, msg) elif action == 'PING': self._handle_ping(hostname, action, msg) elif action == 'END': self._handle_end(hostname, action, msg) elif action == 'START_OK': self._handle_start_ok(hostname, action, msg) else: self.logger.error("<%s> sent unknown action=%s, args=(%s)", hostname, action, msg[1:]) return True def read_event_socket(self): try: msg = self.event_socket.recv_multipart(zmq.NOBLOCK) except zmq.error.Again: return False try: (topic, _, dt, username, data) = (u(m) for m in msg) data = simplejson.loads(data) except ValueError: self.logger.error("Invalid event: %s", msg) return True if topic.endswith(".testjob"): if data["state"] == "Canceling": self.events["canceling"].add(int(data["job"])) elif data["state"] == "Submitted": if "device_type" in data: self.events["available_dt"].add(data["device_type"]) elif topic.endswith(".device"): if data["state"] == "Idle" and data["health"] in [ "Good", "Unknown", "Looping" 
]: self.events["available_dt"].add(data["device_type"]) return True def _handle_end(self, hostname, action, msg): # pylint: disable=unused-argument try: job_id = int(msg[2]) error_msg = msg[3] compressed_description = msg[4] except (IndexError, ValueError): self.logger.error("Invalid message from <%s> '%s'", hostname, msg) return try: job = TestJob.objects.get(id=job_id) except TestJob.DoesNotExist: self.logger.error("[%d] Unknown job", job_id) # ACK even if the job is unknown to let the dispatcher # forget about it send_multipart_u(self.controler, [hostname, 'END_OK', str(job_id)]) return filename = os.path.join(job.output_dir, 'description.yaml') # If description.yaml already exists: a END was already received if os.path.exists(filename): self.logger.info("[%d] %s => END (duplicated), skipping", job_id, hostname) else: if compressed_description: self.logger.info("[%d] %s => END", job_id, hostname) else: self.logger.info( "[%d] %s => END (lava-run crashed, mark job as INCOMPLETE)", job_id, hostname) with transaction.atomic(): # TODO: find a way to lock actual_device job = TestJob.objects.select_for_update() \ .get(id=job_id) job.go_state_finished(TestJob.HEALTH_INCOMPLETE) if error_msg: self.logger.error("[%d] Error: %s", job_id, error_msg) job.failure_comment = error_msg job.save() # Create description.yaml even if it's empty # Allows to know when END messages are duplicated try: # Create the directory if it was not already created mkdir(os.path.dirname(filename)) # TODO: check that compressed_description is not "" description = lzma.decompress(compressed_description) with open(filename, 'w') as f_description: f_description.write(description.decode("utf-8")) if description: parse_job_description(job) except (OSError, lzma.LZMAError) as exc: self.logger.error("[%d] Unable to dump 'description.yaml'", job_id) self.logger.exception("[%d] %s", job_id, exc) # ACK the job and mark the dispatcher as alive send_multipart_u(self.controler, [hostname, 'END_OK', 
str(job_id)]) self.dispatcher_alive(hostname) def _handle_hello(self, hostname, action, msg): # Check the protocol version try: slave_version = int(msg[2]) except (IndexError, ValueError): self.logger.error("Invalid message from <%s> '%s'", hostname, msg) return self.logger.info("%s => %s", hostname, action) if slave_version != PROTOCOL_VERSION: self.logger.error( "<%s> using protocol v%d while master is using v%d", hostname, slave_version, PROTOCOL_VERSION) return send_multipart_u(self.controler, [hostname, 'HELLO_OK']) # If the dispatcher is known and sent an HELLO, means that # the slave has restarted if hostname in self.dispatchers: if action == 'HELLO': self.logger.warning("Dispatcher <%s> has RESTARTED", hostname) else: # Assume the HELLO command was received, and the # action succeeded. self.logger.warning("Dispatcher <%s> was not confirmed", hostname) else: # No dispatcher, treat HELLO and HELLO_RETRY as a normal HELLO # message. self.logger.warning("New dispatcher <%s>", hostname) self.dispatchers[hostname] = SlaveDispatcher(hostname) # Mark the dispatcher as alive self.dispatcher_alive(hostname) def _handle_ping(self, hostname, action, msg): # pylint: disable=unused-argument self.logger.debug("%s => PING(%d)", hostname, PING_INTERVAL) # Send back a signal send_multipart_u( self.controler, [hostname, 'PONG', str(PING_INTERVAL)]) self.dispatcher_alive(hostname) def _handle_start_ok(self, hostname, action, msg): # pylint: disable=unused-argument try: job_id = int(msg[2]) except (IndexError, ValueError): self.logger.error("Invalid message from <%s> '%s'", hostname, msg) return self.logger.info("[%d] %s => START_OK", job_id, hostname) try: with transaction.atomic(): # TODO: find a way to lock actual_device job = TestJob.objects.select_for_update() \ .get(id=job_id) job.go_state_running() job.save() except TestJob.DoesNotExist: self.logger.error("[%d] Unknown job", job_id) else: self.dispatcher_alive(hostname) def export_definition(self, job): # pylint: 
disable=no-self-use
        job_def = yaml.safe_load(job.definition)
        job_def['compatibility'] = job.pipeline_compatibility
        # no need for the dispatcher to retain comments
        return yaml.dump(job_def)

    def save_job_config(self, job, device_cfg, env_str, env_dut_str,
                        dispatcher_cfg):
        """
        Write the configuration files of a job into its output directory.

        Creates ``job.output_dir`` (through ``mkdir``) and writes:

        - ``job.yaml``: the definition returned by ``self.export_definition``;
        - ``device.yaml``: ``device_cfg`` serialized with ``yaml.dump``;
        - ``env.yaml``, ``env.dut.yaml``, ``dispatcher.yaml``: the given
          strings, each file being written only when its value is truthy.
        """
        output_dir = job.output_dir
        mkdir(output_dir)
        with open(os.path.join(output_dir, "job.yaml"), "w") as f_out:
            f_out.write(self.export_definition(job))
        with open(os.path.join(output_dir, "device.yaml"), "w") as f_out:
            yaml.dump(device_cfg, f_out)
        # Optional files: skipped entirely when empty/None.
        if env_str:
            with open(os.path.join(output_dir, "env.yaml"), "w") as f_out:
                f_out.write(env_str)
        if env_dut_str:
            with open(os.path.join(output_dir, "env.dut.yaml"), "w") as f_out:
                f_out.write(env_dut_str)
        if dispatcher_cfg:
            with open(os.path.join(output_dir, "dispatcher.yaml"),
                      "w") as f_out:
                f_out.write(dispatcher_cfg)

    def start_job(self, job):
        """
        Render the configuration of ``job`` and send START to its worker.

        The device configuration is rendered with the ``context`` section of
        the job definition. The env/env.dut/dispatcher files specific to the
        worker are loaded, falling back to the default locations. Everything
        is saved in the job output directory and a START message is sent
        through ``self.controler`` to the worker hosting the job's device.

        Sub jobs flagged ``dynamic_connection`` are started on the same
        worker, with a device configuration trimmed by
        ``minimise_configuration``.

        Exceptions raised while rendering or loading (jinja2, YAML, OSError,
        ...) are not handled here; ``start_jobs`` catches the expected ones.
        """
        # Load job definition to get the variables for template
        # rendering
        job_def = yaml.safe_load(job.definition)
        job_ctx = job_def.get('context', {})

        device = job.actual_device
        worker = device.worker_host

        # TODO: check that device_cfg is not None!
        device_cfg = device.load_configuration(job_ctx)

        # Try to load the dispatcher specific files and then fallback to the
        # default configuration files.
        env_str = load_optional_yaml_file(
            os.path.join(DISPATCHERS_PATH, worker.hostname, "env.yaml"),
            ENV_PATH)
        env_dut_str = load_optional_yaml_file(
            os.path.join(DISPATCHERS_PATH, worker.hostname, "env.dut.yaml"),
            ENV_DUT_PATH)
        dispatcher_cfg = load_optional_yaml_file(
            os.path.join(DISPATCHERS_PATH, worker.hostname, "dispatcher.yaml"),
            os.path.join(DISPATCHERS_PATH, "%s.yaml" % worker.hostname))

        self.save_job_config(job, device_cfg, env_str, env_dut_str,
                             dispatcher_cfg)
        self.logger.info("[%d] START => %s (%s)", job.id,
                         worker.hostname, device.hostname)
        # Frames: worker hostname (presumably the routing identity on the
        # ROUTER socket), command, job id, job definition, device config,
        # dispatcher config, env, env.dut.
        send_multipart_u(self.controler, [
            worker.hostname, 'START', str(job.id),
            self.export_definition(job), yaml.dump(device_cfg),
            dispatcher_cfg, env_str, env_dut_str
        ])

        # For multinode jobs, start the dynamic connections
        parent = job
        for sub_job in job.sub_jobs_list:
            # Skip the job itself and the sub jobs that are not dynamic
            # connections.
            if sub_job == parent or not sub_job.dynamic_connection:
                continue

            # inherit only enough configuration for dynamic_connection operation
            self.logger.info(
                "[%d] Trimming dynamic connection device configuration.",
                sub_job.id)
            min_device_cfg = parent.actual_device.minimise_configuration(
                device_cfg)

            self.save_job_config(sub_job, min_device_cfg, env_str,
                                 env_dut_str, dispatcher_cfg)
            self.logger.info("[%d] START => %s (connection)", sub_job.id,
                             worker.hostname)
            # Same frame layout as above, sent to the parent's worker.
            send_multipart_u(self.controler, [
                worker.hostname, 'START', str(sub_job.id),
                self.export_definition(sub_job), yaml.dump(min_device_cfg),
                dispatcher_cfg, env_str, env_dut_str
            ])

    def start_jobs(self, jobs=None):
        """
        Loop on all scheduled jobs and send the START message to the slave.

        Only jobs in the SCHEDULED state, having a device whose worker is
        online, are selected. ``jobs``, when not None, further restricts the
        selection to those job ids.

        If starting a job fails on a missing template, a template syntax
        error, an ``OSError`` or a YAML error, the error is recorded as a
        failed ``lava``/``job`` test case (``error_type`` Infrastructure) and
        the job is finished as incomplete. Any other exception propagates.

        The query uses ``select_for_update()``; ``main_loop`` calls this
        method inside ``transaction.atomic()``.
        """
        # make the request atomic
        query = TestJob.objects.select_for_update()
        # Only select test job that are ready
        query = query.filter(state=TestJob.STATE_SCHEDULED)
        # Only start jobs on online workers
        query = query.filter(
            actual_device__worker_host__state=Worker.STATE_ONLINE)
        # exclude test job without a device: they are special test jobs like
        # dynamic connection.
        query = query.exclude(actual_device=None)
        # Allow for partial scheduling
        if jobs is not None:
            query = query.filter(id__in=jobs)

        # Loop on all jobs
        for job in query:
            msg = None
            try:
                self.start_job(job)
            except jinja2.TemplateNotFound as exc:
                self.logger.error("[%d] Template not found: '%s'", job.id,
                                  exc.message)
                msg = "Template not found: '%s'" % exc.message
            except jinja2.TemplateSyntaxError as exc:
                self.logger.error(
                    "[%d] Template syntax error in '%s', line %d: %s",
                    job.id, exc.name, exc.lineno, exc.message)
                msg = "Template syntax error in '%s', line %d: %s" % (
                    exc.name, exc.lineno, exc.message)
            except OSError as exc:
                self.logger.error("[%d] Unable to read '%s': %s", job.id,
                                  exc.filename, exc.strerror)
                msg = "Cannot open '%s': %s" % (exc.filename, exc.strerror)
            except yaml.YAMLError as exc:
                self.logger.error("[%d] Unable to parse job definition: %s",
                                  job.id, exc)
                msg = "Cannot parse job definition: %s" % exc

            if msg:
                # Add the error as lava.job result
                metadata = {
                    "case": "job",
                    "definition": "lava",
                    "error_type": "Infrastructure",
                    "error_msg": msg,
                    "result": "fail"
                }
                suite, _ = TestSuite.objects.get_or_create(name="lava",
                                                          job=job)
                TestCase.objects.create(name="job", suite=suite,
                                        result=TestCase.RESULT_FAIL,
                                        metadata=yaml.dump(metadata))
                job.go_state_finished(TestJob.HEALTH_INCOMPLETE, True)
                job.save()

    def cancel_jobs(self, partial=False):
        """
        Send a CANCEL message for every job in the CANCELING state.

        Only jobs whose device worker is online are selected. When
        ``partial`` is true, the selection is restricted to the job ids held
        in ``self.events["canceling"]``.
        """
        # make the request atomic
        query = TestJob.objects.select_for_update()
        # Only select the test job that are canceling
        query = query.filter(state=TestJob.STATE_CANCELING)
        # Only cancel jobs on online workers
        query = query.filter(
            actual_device__worker_host__state=Worker.STATE_ONLINE)

        # Allow for partial canceling
        if partial:
            query = query.filter(id__in=list(self.events["canceling"]))

        # Loop on all jobs
        for job in query:
            # NOTE(review): the worker-state filter above goes through
            # actual_device, while start_jobs describes dynamic connections
            # as jobs without a device; such jobs would presumably never
            # match, making the lookup_worker branch unreachable - confirm.
            worker = job.lookup_worker if job.dynamic_connection else job.actual_device.worker_host
            self.logger.info("[%d] CANCEL => %s", job.id, worker.hostname)
            send_multipart_u(self.controler,
                             [worker.hostname, 'CANCEL', str(job.id)])

    def handle(self, *args, **options):
        """
        Entry point of the lava-master command.

        The steps are:

        1. Set up logging and drop privileges.
        2. Dump ``options`` to ``MEDIA_ROOT/lava-master-config.yaml``.
        3. Mark every worker offline.
        4. Create the ROUTER socket used to talk to the workers
           (``self.controler``) and the SUB socket receiving events
           (``self.event_socket``).
        5. Enable IPv6 and CURVE encryption if ``options`` asks for them.
        6. Run ``main_loop``.

        Whatever way ``main_loop`` exits, the sockets are then closed, the
        authenticator is stopped (when encrypting) and the ZMQ context is
        terminated.
        """
        # Initialize logging.
        self.setup_logging("lava-master", options["level"],
                           options["log_file"], FORMAT)

        self.logger.info("[INIT] Dropping privileges")
        if not self.drop_privileges(options['user'], options['group']):
            self.logger.error("[INIT] Unable to drop privileges")
            return

        filename = os.path.join(settings.MEDIA_ROOT,
                                'lava-master-config.yaml')
        self.logger.debug("[INIT] Dumping config to %s", filename)
        with open(filename, 'w') as output:
            yaml.dump(options, output)

        self.logger.info("[INIT] Marking all workers as offline")
        with transaction.atomic():
            for worker in Worker.objects.select_for_update().all():
                worker.go_state_offline()
                worker.save()

        # Create the sockets
        context = zmq.Context()
        self.controler = context.socket(zmq.ROUTER)
        self.event_socket = context.socket(zmq.SUB)

        if options['ipv6']:
            self.logger.info("[INIT] Enabling IPv6")
            self.controler.setsockopt(zmq.IPV6, 1)
            self.event_socket.setsockopt(zmq.IPV6, 1)

        if options['encrypt']:
            self.logger.info("[INIT] Starting encryption")
            try:
                self.auth = ThreadAuthenticator(context)
                self.auth.start()
                self.logger.debug("[INIT] Opening master certificate: %s",
                                  options['master_cert'])
                master_public, master_secret = zmq.auth.load_certificate(
                    options['master_cert'])
                self.logger.debug("[INIT] Using slaves certificates from: %s",
                                  options['slaves_certs'])
                # Accept CURVE clients whose public certificates are in the
                # slaves_certs directory.
                self.auth.configure_curve(domain='*',
                                          location=options['slaves_certs'])
            except OSError as err:
                self.logger.error(err)
                # NOTE(review): this early return leaves self.controler,
                # self.event_socket and the ZMQ context open. If
                # ThreadAuthenticator(context) itself raised, self.auth may
                # not be set at this point - confirm.
                self.auth.stop()
                return
            # The master acts as the CURVE server.
            self.controler.curve_publickey = master_public
            self.controler.curve_secretkey = master_secret
            self.controler.curve_server = True

            # Watch the certificates directory; main_loop reloads the
            # certificates when this fd becomes readable.
            self.logger.debug("[INIT] Watching %s", options["slaves_certs"])
            self.inotify_fd = watch_directory(options["slaves_certs"])
            if self.inotify_fd is None:
                self.logger.error("[INIT] Unable to start inotify")

        self.controler.setsockopt(zmq.IDENTITY, b"master")
        # From http://api.zeromq.org/4-2:zmq-setsockopt#toc42
        # "If two clients use the same identity when connecting to a ROUTER
        # [...] the ROUTER socket shall hand-over the connection to the new
        # client and disconnect the existing one."
        self.controler.setsockopt(zmq.ROUTER_HANDOVER, 1)
        self.controler.bind(options['master_socket'])

        # Subscribe to the event topic and connect to the publisher.
        self.event_socket.setsockopt(zmq.SUBSCRIBE, b(settings.EVENT_TOPIC))
        self.event_socket.connect(options['event_url'])

        # Poll on the sockets. This allow to have a
        # nice timeout along with polling.
        self.poller = zmq.Poller()
        self.poller.register(self.controler, zmq.POLLIN)
        self.poller.register(self.event_socket, zmq.POLLIN)
        # NOTE(review): self.inotify_fd is only assigned above when
        # encryption is enabled; presumably it defaults to None elsewhere
        # (e.g. in __init__) - confirm.
        if self.inotify_fd is not None:
            self.poller.register(os.fdopen(self.inotify_fd), zmq.POLLIN)

        # Translate signals into zmq messages
        (self.pipe_r, _) = self.setup_zmq_signal_handler()
        self.poller.register(self.pipe_r, zmq.POLLIN)

        self.logger.info("[INIT] LAVA master has started.")
        self.logger.info("[INIT] Using protocol version %d", PROTOCOL_VERSION)

        try:
            self.main_loop(options)
        except BaseException as exc:
            self.logger.error("[CLOSE] Unknown exception raised, leaving!")
            self.logger.exception(exc)
        finally:
            # Drop controler socket: the protocol does handle lost messages
            self.logger.info(
                "[CLOSE] Closing the controler socket and dropping messages")
            self.controler.close(linger=0)
            self.event_socket.close(linger=0)
            if options['encrypt']:
                self.auth.stop()
            context.term()

    def main_loop(self, options):
        """
        Serve the workers and the events until a signal is received.

        Each iteration polls the registered sockets. The timeout runs until
        the next scheduling (SCHEDULE_INTERVAL) or dispatcher check
        (PING_INTERVAL) deadline, capped at 2s when ``self.events`` has
        pending entries and never shorter than 1ms. The iteration then:

        - leaves the loop when a signal arrives on ``self.pipe_r``;
        - drains the controler socket and the event socket; an event ends the
          iteration early so the emitter can commit its transaction;
        - reloads the client certificates when the inotify fd is readable;
        - marks dispatchers silent for more than DISPATCHER_TIMEOUT offline;
        - schedules, starts and cancels jobs. Every SCHEDULE_INTERVAL this is
          a full pass; otherwise only the pending ``self.events`` are handled.

        On ``OperationalError``/``InterfaceError`` the Django connection is
        closed and the loop resumes after 2 seconds.
        """
        last_schedule = last_dispatcher_check = time.time()

        while True:
            try:
                try:
                    # Compute the timeout
                    now = time.time()
                    timeout = min(
                        SCHEDULE_INTERVAL - (now - last_schedule),
                        PING_INTERVAL - (now - last_dispatcher_check))
                    # If some actions are remaining, decrease the timeout
                    if any([self.events[k] for k in self.events.keys()]):
                        timeout = min(timeout, 2)
                    # Wait at least for 1ms
                    # (seconds converted to the poller's milliseconds)
                    timeout = max(timeout * 1000, 1)

                    # Wait for data or a timeout
                    sockets = dict(self.poller.poll(timeout))
                except zmq.error.ZMQError:
                    # A failed poll simply restarts the iteration.
                    continue

                if sockets.get(self.pipe_r) == zmq.POLLIN:
                    self.logger.info("[POLL] Received a signal, leaving")
                    break

                # Command socket
                if sockets.get(self.controler) == zmq.POLLIN:
                    while self.controler_socket():  # Unqueue all pending messages
                        pass

                # Events socket
                if sockets.get(self.event_socket) == zmq.POLLIN:
                    while self.read_event_socket():  # Unqueue all pending messages
                        pass
                    # Wait for the next iteration to handle the event.
                    # In fact, the code that generated the event (lava-logs or
                    # lava-server-gunicorn) needs some time to commit the
                    # database transaction.
                    # If we are too fast, the database object won't be
                    # available (or in the right state) yet.
                    continue

                # Inotify socket
                if sockets.get(self.inotify_fd) == zmq.POLLIN:
                    # Consume (up to 4096 bytes of) pending inotify data; its
                    # content is discarded, it only triggers the reload.
                    os.read(self.inotify_fd, 4096)
                    self.logger.debug("[AUTH] Reloading certificates from %s",
                                      options['slaves_certs'])
                    self.auth.configure_curve(domain='*',
                                              location=options['slaves_certs'])

                # Check dispatchers status
                now = time.time()
                if now - last_dispatcher_check > PING_INTERVAL:
                    for hostname, dispatcher in self.dispatchers.items():
                        if dispatcher.online and now - dispatcher.last_msg > DISPATCHER_TIMEOUT:
                            if hostname == "lava-logs":
                                self.logger.error(
                                    "[STATE] lava-logs goes OFFLINE")
                            else:
                                self.logger.error(
                                    "[STATE] Dispatcher <%s> goes OFFLINE",
                                    hostname)
                            self.dispatchers[hostname].go_offline()
                    last_dispatcher_check = now

                # Limit accesses to the database. This will also limit the rate of
                # CANCEL and START messages
                if time.time() - last_schedule > SCHEDULE_INTERVAL:
                    # Jobs are only scheduled/started while lava-logs is
                    # online; cancellation below runs regardless.
                    if self.dispatchers["lava-logs"].online:
                        schedule(self.logger)

                        # Dispatch scheduled jobs
                        with transaction.atomic():
                            self.start_jobs()
                    else:
                        self.logger.warning(
                            "lava-logs is offline: can't schedule jobs")

                    # Handle canceling jobs
                    with transaction.atomic():
                        self.cancel_jobs()

                    # Do not count the time taken to schedule jobs
                    last_schedule = time.time()
                else:
                    # Cancel the jobs and remove the jobs from the set
                    if self.events["canceling"]:
                        with transaction.atomic():
                            self.cancel_jobs(partial=True)
                        self.events["canceling"] = set()
                    # Schedule for available device-types
                    # NOTE(review): unlike the full pass above, this path
                    # does not check that lava-logs is online - confirm
                    # this is intended.
                    if self.events["available_dt"]:
                        jobs = schedule(self.logger,
                                        self.events["available_dt"])
                        self.events["available_dt"] = set()
                        # Dispatch scheduled jobs
                        with transaction.atomic():
                            self.start_jobs(jobs)

            except (OperationalError, InterfaceError):
                self.logger.info("[RESET] database connection reset.")
                # Closing the database connection will force Django to reopen
                # the connection
                connection.close()
                time.sleep(2)
def main():
    """
    Run SEND in either transmitter or receiver mode.

    Parses the command line and starts a ZMQ ``ThreadAuthenticator`` on
    ``manager.ctx``. The authenticator uses the CURVE keys in ``PUBKEYS``
    and allows connections from localhost and, when given, from ``--ip``.

    It then launches the publish (``--transmit``) or subscribe
    (``--receive``) transfer thread and waits for it to finish. Finally it
    stops the authenticator and destroys the ZMQ context.

    Raises:
        ValueError: if neither ``--transmit`` nor ``--receive`` is given.
        OSError: propagated from the transfer setup.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-t", "--transmit", action="store_true",
        help="Flag indicating that user will be transmitting files"
    )
    parser.add_argument(
        "-r", "--receive", action="store_true",
        help="Flag indicating that user will be receiving files"
    )
    parser.add_argument(
        "--location",
        help="Location of files to send/receive. Can be a specific file if tx."
    )
    parser.add_argument(
        "--ip",
        help="IP Address to form connection with"
    )
    parser.add_argument(
        "--port", nargs='?', const=6000, default=6000, type=int,
        help="Port to form connection with (only needed if using non-default)"
    )
    parser.add_argument(
        "--public_key", nargs='?',
        help="Public Key of transmitter in plain-text (only needed if receiver)"
    )
    args = parser.parse_args()

    # Security Authentication Thread
    _generate_security_keys()
    authenticator = ThreadAuthenticator(manager.ctx)
    authenticator.start()
    # --ip is optional (a transmitter does not need it). Passing None to
    # allow() breaks it, because it joins the addresses as strings, so only
    # add the peer address when it was actually supplied.
    whitelist = ["127.0.0.1"]
    if args.ip is not None:
        whitelist.append(args.ip)
    authenticator.allow(*whitelist)
    authenticator.configure_curve(domain="*", location=PUBKEYS)

    # Bound up-front so the cleanup below never raises UnboundLocalError
    # (which would mask the real error and skip the teardown) when the
    # setup fails or is interrupted before a thread exists.
    thread = None
    try:
        if args.transmit:
            thread = manager.publish_folder(args.port, args.location)
        elif args.receive:
            thread = manager.subscribe_folder(
                args.ip, args.port, args.location, args.public_key
            )
        else:
            raise ValueError("User did not specify transmit/receive")
    except KeyboardInterrupt:
        pass
    finally:
        if thread is not None:
            # Keep things rolling until the transfer is done or the thread
            # dies from timing out. join() blocks without spinning a CPU
            # core; the former isAlive() polling loop also no longer works on
            # Python >= 3.9, where isAlive() was removed.
            thread.join()
        # Clean up and close everything out
        authenticator.stop()
        # Use destroy versus term: https://github.com/zeromq/pyzmq/issues/991
        manager.ctx.destroy()