Exemple #1
0
class Authenticator(object):
    '''CURVE authentication helper wrapping a single ThreadAuthenticator.

    Instances are cached per public-key directory; prefer :meth:`instance`
    over calling the constructor so that only one authenticator is created
    for each directory.
    '''
    # Cache of created authenticators, keyed by public-key directory.
    _authenticators = {}

    @classmethod
    def instance(cls, public_keys_dir):
        '''Return the authenticator cached for ``public_keys_dir``.

        The authenticator is created and cached on first use.
        '''
        try:
            return cls._authenticators[public_keys_dir]
        except KeyError:
            pass
        created = cls(public_keys_dir)
        cls._authenticators[public_keys_dir] = created
        return created

    def __init__(self, public_keys_dir):
        '''Start an authenticator thread on the global zmq context and load
        the client public keys found in ``public_keys_dir``.'''
        self._auth = ThreadAuthenticator(zmq.Context.instance())
        authenticator = self._auth
        authenticator.start()
        # NOTE(review): allow() normally takes IP addresses; confirm that the
        # literal '*' is treated as "any address" by the pyzmq version in use.
        authenticator.allow('*')
        authenticator.configure_curve(domain='*', location=public_keys_dir)

    def set_server_key(self, zmq_socket, server_secret_key_path):
        '''Load the server key pair onto ``zmq_socket`` and mark it as a
        CURVE server. Must be called before the socket is bound.'''
        load_and_set_key(zmq_socket, server_secret_key_path)
        zmq_socket.curve_server = True

    def set_client_key(self, zmq_socket, client_secret_key_path, server_public_key_path):
        '''Load the client key pair and the server public key onto
        ``zmq_socket``. Must be called before the socket is bound or connected.'''
        load_and_set_key(zmq_socket, client_secret_key_path)
        server_public_key, _unused_secret = zmq.auth.load_certificate(server_public_key_path)
        zmq_socket.curve_serverkey = server_public_key

    def stop(self):
        '''Stop the underlying authenticator thread.'''
        self._auth.stop()
Exemple #2
0
def main():
    """Run a PULL socket that only accepts peers whitelisted by IP address.

    The ZAP authenticator is started and its whitelist installed *before*
    the socket is bound, so no peer can connect while authentication is not
    yet in place. The loop prints every received multipart message until
    interrupted with Ctrl-C; the authenticator and the socket are always
    released on exit.
    """
    localhost = socket_m.getfqdn()

    port = "5556"
    # First address reported for this host; use "*" to bind all interfaces.
    ip = socket_m.gethostbyaddr(localhost)[2][0]

    context = zmq.Context()
    auth = ThreadAuthenticator(context)
    auth.start()

    socket = None
    try:
        host = localhost
        # Set to None to experiment without an address whitelist.
        whitelist = socket_m.gethostbyaddr(host)[2][0]

        if whitelist is None:
            # NOTE(review): kept from the original; the effect of assigning
            # `auth.auth` is not visible from this snippet - confirm.
            auth.auth = None
        else:
            auth.allow(whitelist)

        socket = context.socket(zmq.PULL)
        # A ZAP domain is required for the authenticator to be consulted.
        socket.zap_domain = b'global'
        socket.bind("tcp://" + ip + ":%s" % port)

        while True:
            message = socket.recv_multipart()
            print("received reply ", message)
    except KeyboardInterrupt:
        pass
    finally:
        auth.stop()
        if socket is not None:
            socket.close(0)
Exemple #3
0
def run_mdp_broker():
    """Command-line entry point: start a (optionally CURVE-secured) MDP broker.

    Reads the ``[mdp-broker]`` section of the config file named on the
    command line, validates it with ``SettingsSchema`` and serves forever.
    With ``--secure`` a fresh broker/client key pair is generated and logged.

    The authenticator thread (only created for a secure broker) is always
    stopped when the broker exits, whether normally or through an exception.
    """
    args = docopt("""Usage:
        mdp-broker [options] <config>

    Options:
        -h --help                 show this help message and exit
        -s --secure               generate (and print) client & broker keys for a secure server
    """)
    global log
    _setup_logging(args['<config>'])

    log = logging.getLogger(__name__)

    cp = ConfigParser()
    cp.read(args['<config>'])

    # Parse settings a bit
    raw = {
        option: cp.get('mdp-broker', option)
        for option in cp.options('mdp-broker')}
    s = SettingsSchema().to_python(raw)

    if args['--secure']:
        broker_key = Key.generate()
        client_key = Key.generate()
        s['key'] = dict(
            broker=broker_key,
            client=client_key)
        # NOTE(review): the client secret key is written to the log on
        # purpose (see --secure help text); make sure the log is protected.
        log.info('Auto-generated keys: %s_%s_%s',
            broker_key.public, client_key.public, client_key.secret)
        log.info(' broker.public: %s', broker_key.public)
        log.info(' client.public: %s', client_key.public)
        log.info(' client.secret: %s', client_key.secret)

    # Stays None for an insecure broker so cleanup can tell the cases apart.
    auth = None
    try:
        if s['key']:
            log.info('Starting secure mdp-broker on %s', s['uri'])
            auth = ThreadAuthenticator()
            auth.start()
            # Accept only the configured client public key, in every domain.
            auth.thread.authenticator.certs['*'] = {
                s['key']['client'].public: 'OK'}

            broker = SecureMajorDomoBroker(s['key']['broker'], s['uri'])
        else:
            log.info('Starting mdp-broker on %s', s['uri'])
            broker = MajorDomoBroker(s['uri'])
        broker.serve_forever()
    finally:
        if auth is not None:
            auth.stop()
def main():
    """Run a CURVE-secured MajorDomo broker on the endpoint in ``sys.argv[1]``.

    Only peers connecting from 127.0.0.1 are admitted; any CURVE client key
    is accepted. The authenticator thread is stopped on every exit path
    (not just Ctrl-C), and any exception still propagates to the caller.
    """
    auth = ThreadAuthenticator(zmq.Context.instance())
    auth.start()
    try:
        auth.allow('127.0.0.1')
        # Tell the authenticator how to handle CURVE requests
        auth.configure_curve(domain='*', location=zmq.auth.CURVE_ALLOW_ANY)

        key = Key.load('example/broker.key_secret')
        broker = SecureMajorDomoBroker(key, sys.argv[1])
        broker.serve_forever()
    finally:
        auth.stop()
def main():
    """Start a secure MajorDomo broker bound to the URI given as ``sys.argv[1]``.

    Connections are restricted to 127.0.0.1 and every CURVE client key is
    accepted. Cleanup runs in a ``finally`` so the authenticator thread is
    stopped for any failure, while the exception itself is still re-raised.
    """
    auth = ThreadAuthenticator(zmq.Context.instance())
    auth.start()
    try:
        auth.allow('127.0.0.1')
        # Tell the authenticator how to handle CURVE requests
        auth.configure_curve(domain='*', location=zmq.auth.CURVE_ALLOW_ANY)

        broker_key = Key.load('example/broker.key_secret')
        broker = SecureMajorDomoBroker(broker_key, sys.argv[1])
        broker.serve_forever()
    finally:
        auth.stop()
Exemple #6
0
def run_mdp_broker():
    """Parse the command line and config file, then run an MDP broker.

    The ``[mdp-broker]`` config section is converted with ``SettingsSchema``.
    If ``--secure`` is given, new broker and client keys are generated and
    logged. When keys are configured a ``ThreadAuthenticator`` is started
    that admits only the configured client key.

    The authenticator is created lazily and stopped in a ``finally`` block,
    so an insecure broker no longer hits an unbound ``auth`` during error
    handling and a secure broker never leaves the thread running.
    """
    args = docopt("""Usage:
        mdp-broker [options] <config>

    Options:
        -h --help                 show this help message and exit
        -s --secure               generate (and print) client & broker keys for a secure server
    """)
    global log
    _setup_logging(args['<config>'])

    log = logging.getLogger(__name__)

    cp = ConfigParser()
    cp.read(args['<config>'])

    # Parse settings a bit
    raw = {option: cp.get('mdp-broker', option)
           for option in cp.options('mdp-broker')}
    s = SettingsSchema().to_python(raw)

    if args['--secure']:
        broker_key = Key.generate()
        client_key = Key.generate()
        s['key'] = dict(broker=broker_key, client=client_key)
        # NOTE(review): logging the client secret is the documented purpose
        # of --secure; keep the log destination private.
        log.info('Auto-generated keys: %s_%s_%s', broker_key.public,
                 client_key.public, client_key.secret)
        log.info(' broker.public: %s', broker_key.public)
        log.info(' client.public: %s', client_key.public)
        log.info(' client.secret: %s', client_key.secret)

    auth = None  # only set when a secure broker is started
    try:
        if s['key']:
            log.info('Starting secure mdp-broker on %s', s['uri'])
            auth = ThreadAuthenticator()
            auth.start()
            # Whitelist the single configured client key for all domains.
            auth.thread.authenticator.certs['*'] = {
                s['key']['client'].public: 'OK'
            }

            broker = SecureMajorDomoBroker(s['key']['broker'], s['uri'])
        else:
            log.info('Starting mdp-broker on %s', s['uri'])
            broker = MajorDomoBroker(s['uri'])
        broker.serve_forever()
    finally:
        if auth is not None:
            auth.stop()
Exemple #7
0
def main():
    """Experiment: PUSH five messages while switching the peer from allowed
    to denied after the second message.

    The local host's first IP is whitelisted initially. After the iteration
    with ``i == 1`` the authenticator is restarted with that IP blacklisted
    and the socket is recreated, so later sends go through a socket whose
    peer is denied. Message frames are ``bytes`` so that zero-copy sending
    also works on Python 3.
    """
    port = "5556"
    socket_ip = "*"

    context = zmq.Context()
    auth = ThreadAuthenticator(context)
    auth.start()

    zmq_socket = None
    try:
        whitelist = [socket.getfqdn()]
        for host in whitelist:
            _hostname, _aliases, addresses = socket.gethostbyaddr(host)
            auth.allow(addresses[0])

        zmq_socket = context.socket(zmq.PUSH)
        zmq_socket.zap_domain = b'global'
        zmq_socket.bind("tcp://" + socket_ip + ":%s" % port)

        for i in range(5):
            # copy=False builds zmq Frames, which reject unicode strings.
            message = [b"World"]
            print("Send: ", message)
            res = zmq_socket.send_multipart(message, copy=False, track=True)
            if res.done:
                print("res: done")
            else:
                print("res: waiting")
                res.wait()
                print("res: waiting...")
            print("sleeping...")
            if i == 1:
                # Tear everything down and come back with the peer denied.
                auth.stop()
                zmq_socket.close(0)

                auth.start()
                ip = socket.gethostbyaddr(socket.getfqdn())[2]
                auth.deny(ip[0])
                zmq_socket = context.socket(zmq.PUSH)
                zmq_socket.zap_domain = b'global'
                zmq_socket.bind("tcp://" + socket_ip + ":%s" % port)
            time.sleep(1)
            print("sleeping...done")
    finally:
        auth.stop()
        if zmq_socket is not None:
            zmq_socket.close(0)
def setup_auth():
    """Enable CURVE authentication on the module socket if configured.

    Reads the ``auth`` entry of the module-level ``_options``; when it is
    absent nothing happens. Otherwise a ``ThreadAuthenticator`` is started on
    ``_zctx``, the optional whitelist and the public-key directory are
    installed, and the server key pair is loaded onto ``_sock``.

    Setup is best-effort: on any error the failure is logged, the
    authenticator (if it was created) is stopped and ``_auth`` is reset to
    ``None``. No exception is propagated to the caller.
    """
    import logging  # local import: the file's import block is not visible here

    global _auth
    assert _options is not None
    auth = _options.get('auth', None)
    if auth is None:
        return
    base_dir = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), '..'))
    # Tracked locally so the error path never touches a half-initialised or
    # missing global.
    authenticator = None
    try:
        authenticator = ThreadAuthenticator(_zctx)
        _auth = authenticator
        authenticator.start()
        whitelist = auth.get('whitelist', None)
        if whitelist is not None:
            # NOTE(review): passed as a single argument, as in the original;
            # confirm whether the config holds one address or a list.
            authenticator.allow(whitelist)
        public_path = auth.get('public_key_dir', 'public_keys')
        authenticator.configure_curve(domain='*', location=getExistsPath(base_dir, public_path))
        private_dir = getExistsPath(base_dir, auth.get('private_key_dir', 'private_keys'))
        private_key = os.path.join(private_dir, auth.get('private_key_file', 'server.key_secret'))
        server_public, server_private = zmq.auth.load_certificate(private_key)
        _sock.curve_secretkey = server_private
        _sock.curve_publickey = server_public
        _sock.curve_server = True
    except Exception:
        logging.getLogger(__name__).exception('setup_auth failed; authentication disabled')
        if authenticator is not None:
            authenticator.stop()
        _auth = None
Exemple #9
0
  def run(self):
    """Serve encrypted (CURVE) request/reply traffic until ``self.stop`` is set.

    Creates a zmq context, starts a ThreadAuthenticator on it, loads the
    server certificate, binds a REP socket and then alternates between
    receiving a request and sending the reply produced by
    ``taco.commands.Proccess_Request``. On exit the socket, authenticator
    and context are shut down.
    """
    self.set_status("Server Startup")
    
    self.set_status("Creating zmq Contexts",1)
    serverctx = zmq.Context() 
    
    self.set_status("Starting zmq ThreadedAuthenticator",1)
    #serverauth = zmq.auth.ThreadedAuthenticator(serverctx)
    serverauth = ThreadAuthenticator(serverctx)
    serverauth.start()
    
    # Snapshot every setting needed below while holding the settings lock.
    with taco.globals.settings_lock:
      bindip     = taco.globals.settings["Application IP"]
      bindport   = taco.globals.settings["Application Port"]
      localuuid  = taco.globals.settings["Local UUID"]
      publicdir  = os.path.normpath(os.path.abspath(taco.globals.settings["TacoNET Certificates Store"] + "/"  + taco.globals.settings["Local UUID"] + "/public/"))
      privatedir = os.path.normpath(os.path.abspath(taco.globals.settings["TacoNET Certificates Store"] + "/"  + taco.globals.settings["Local UUID"] + "/private/"))

    self.set_status("Configuring Curve to use publickey dir:" + publicdir)
    serverauth.configure_curve(domain='*', location=publicdir)
    #auth.configure_curve(domain='*', location=zmq.auth.CURVE_ALLOW_ANY)

    self.set_status("Creating Server Context",1)
    server = serverctx.socket(zmq.REP)
    server.setsockopt(zmq.LINGER, 0)

    self.set_status("Loading Server Certs",1)
    server_public, server_secret = zmq.auth.load_certificate(os.path.normpath(os.path.abspath(privatedir + "/" + taco.constants.KEY_GENERATION_PREFIX +"-server.key_secret")))
    server.curve_secretkey = server_secret
    server.curve_publickey = server_public
   
    # The CURVE options are set before bind() below.
    server.curve_server = True
    # zmq uses "*" rather than 0.0.0.0 to mean "all interfaces".
    if bindip == "0.0.0.0": bindip ="*"
    self.set_status("Server is now listening for encrypted ZMQ connections @ "+ "tcp://" + bindip +":" + str(bindport)) 
    server.bind("tcp://" + bindip +":" + str(bindport))
    
    poller = zmq.Poller()
    poller.register(server, zmq.POLLIN|zmq.POLLOUT)

    # NOTE(review): there is no try/finally around this loop, so an exception
    # raised while serving skips the socket/authenticator/context shutdown below.
    while not self.stop.is_set():
      # Phase 1: wait up to 200 ms for an incoming request.
      socks = dict(poller.poll(200))
      if server in socks and socks[server] == zmq.POLLIN:
        #self.set_status("Getting a request")
        data = server.recv()
        with taco.globals.download_limiter_lock: taco.globals.download_limiter.add(len(data))
        (client_uuid,reply) = taco.commands.Proccess_Request(data)
        if client_uuid!="0": self.set_client_last_request(client_uuid)
      # Phase 2: send the pending reply once the socket is writable.
      # NOTE(review): `reply` is only bound in phase 1; this presumably relies
      # on a REP socket not reporting POLLOUT before a request was received - confirm.
      socks = dict(poller.poll(10))
      if server in socks and socks[server] == zmq.POLLOUT:
        #self.set_status("Replying to a request")
        with taco.globals.upload_limiter_lock: taco.globals.upload_limiter.add(len(reply))
        server.send(reply)

        


    self.set_status("Stopping zmq server with 0 second linger")
    server.close(0)
    self.set_status("Stopping zmq ThreadedAuthenticator")
    serverauth.stop() 
    serverctx.term()
    self.set_status("Server Exit")    
Exemple #10
0
class MispZmq:
    """Bridge that republishes messages from Redis lists on a ZeroMQ PUB socket.

    Settings are read from ``tmp/mispzmq_settings.json`` (relative to the
    parent of this file's directory). Data lists named
    ``<namespace>:data:<topic>`` are forwarded to ZMQ as ``"<topic> <data>"``
    and a ``<namespace>:command`` list accepts the commands ``kill``,
    ``reload`` and ``status``. A PID file guards against running twice.
    """
    # Number of data messages forwarded / number of those on topic "misp_json".
    message_count = 0
    publish_count = 0

    # Runtime resources; they stay None until _setup_zmq()/__init__ create them.
    monitor_thread = None
    auth = None
    socket = None
    pidfile = None

    r: redis.StrictRedis
    namespace: str

    def __init__(self):
        """Check the PID file, then load the settings and connect to Redis.

        Raises:
            Exception: if another instance is running according to the PID
                file, or if the settings file does not exist.
        """
        self._logger = logging.getLogger()

        self.tmp_location = Path(__file__).parent.parent / "tmp"
        self.pidfile = self.tmp_location / "mispzmq.pid"
        if self.pidfile.exists():
            with open(self.pidfile.as_posix()) as f:
                pid = f.read()
            if check_pid(pid):
                raise Exception(
                    "mispzmq already running on PID {}".format(pid))
            else:
                # Stale PID file left behind by a process that is gone.
                self.pidfile.unlink()
        if (self.tmp_location / "mispzmq_settings.json").exists():
            self._setup()
        else:
            raise Exception("The settings file is missing.")

    def _setup(self):
        """(Re)load the JSON settings file and create the Redis client."""
        with open((self.tmp_location /
                   "mispzmq_settings.json").as_posix()) as settings_file:
            self.settings = json.load(settings_file)
        self.namespace = self.settings["redis_namespace"]
        self.r = redis.StrictRedis(host=self.settings["redis_host"],
                                   db=self.settings["redis_database"],
                                   password=self.settings["redis_password"],
                                   port=self.settings["redis_port"],
                                   decode_responses=True)
        # Also used by main() as the reference point for the reported uptime.
        self.timestamp_settings = time.time()
        self._logger.debug("Connected to Redis {}:{}/{}".format(
            self.settings["redis_host"], self.settings["redis_port"],
            self.settings["redis_database"]))

    def _setup_zmq(self):
        """Create the PUB socket, with PLAIN authentication if a username is set.

        Raises:
            Exception: if a username is configured without a password.
        """
        # NOTE(review): on "reload" this builds a new context and socket
        # without closing the previous socket (and, when a username is set,
        # without stopping the previous authenticator); confirm that
        # re-binding the same endpoint works as intended.
        context = zmq.Context()

        # "username" is optional; read it once so a missing key never raises.
        username = self.settings.get("username")

        if username:
            if not self.settings.get("password"):
                raise Exception(
                    "When username is set, password cannot be empty.")

            self.auth = ThreadAuthenticator(context)
            self.auth.start()
            self.auth.configure_plain(domain="*",
                                      passwords={
                                          username:
                                          self.settings["password"]
                                      })
        else:
            if self.auth:
                self.auth.stop()
            self.auth = None

        self.socket = context.socket(zmq.PUB)
        if username:
            self.socket.plain_server = True  # must come before bind
        self.socket.bind("tcp://{}:{}".format(self.settings["host"],
                                              self.settings["port"]))
        self._logger.debug("ZMQ listening on tcp://{}:{}".format(
            self.settings["host"], self.settings["port"]))

        # Socket events are only monitored when debug logging is enabled.
        if self._logger.isEnabledFor(logging.DEBUG):
            monitor = self.socket.get_monitor_socket()
            self.monitor_thread = threading.Thread(target=event_monitor,
                                                   args=(monitor,
                                                         self._logger))
            self.monitor_thread.start()
        else:
            if self.monitor_thread:
                self.socket.disable_monitor()
            self.monitor_thread = None

    def _handle_command(self, command):
        """Execute one control command received on the command list.

        ``kill`` cleans up and exits the process, ``reload`` re-reads the
        settings and recreates the ZMQ socket, ``status`` replaces the
        ``<namespace>:status`` list with a JSON snapshot of the counters.
        Any other value is logged as invalid.
        """
        if command == "kill":
            self._logger.info("Kill command received, shutting down.")
            self.clean()
            sys.exit()

        elif command == "reload":
            self._logger.info(
                "Reload command received, reloading settings from file.")
            self._setup()
            self._setup_zmq()

        elif command == "status":
            self._logger.info(
                "Status command received, responding with latest stats.")
            self.r.delete("{}:status".format(self.namespace))
            self.r.lpush(
                "{}:status".format(self.namespace),
                json.dumps({
                    "timestamp": time.time(),
                    "timestampSettings": self.timestamp_settings,
                    "publishCount": self.publish_count,
                    "messageCount": self.message_count
                }))
        else:
            self._logger.warning(
                "Received invalid command '{}'.".format(command))

    def _create_pid_file(self):
        """Write the current process id to the PID file."""
        with open(self.pidfile.as_posix(), "w") as f:
            f.write(str(os.getpid()))

    def _pub_message(self, topic, data):
        """Publish ``data`` on ``topic`` as the single string "<topic> <data>"."""
        self.socket.send_string("{} {}".format(topic, data))

    def clean(self):
        """Release the monitor, authenticator and socket; remove the PID file."""
        if self.monitor_thread:
            self.socket.disable_monitor()
        if self.auth:
            self.auth.stop()
        if self.socket:
            self.socket.close()
        if self.pidfile:
            self.pidfile.unlink()

    def main(self):
        """Run the forwarding loop; it only ends through the ``kill`` command."""
        self._create_pid_file()
        self._setup_zmq()
        time.sleep(1)

        status_array = [
            "And when you're dead I will be still alive.",
            "And believe me I am still alive.",
            "I'm doing science and I'm still alive.",
            "I feel FANTASTIC and I'm still alive.",
            "While you're dying I'll be still alive.",
        ]
        topics = [
            "misp_json", "misp_json_event", "misp_json_attribute",
            "misp_json_sighting", "misp_json_organisation", "misp_json_user",
            "misp_json_conversation", "misp_json_object",
            "misp_json_object_reference", "misp_json_audit", "misp_json_tag",
            "misp_json_warninglist"
        ]

        # The command list comes first in the key list handed to BLPOP.
        lists = ["{}:command".format(self.namespace)]
        for topic in topics:
            lists.append("{}:data:{}".format(self.namespace, topic))

        while True:
            data = self.r.blpop(lists, timeout=10)

            if data is None:
                # redis timeout expired
                current_time = int(time.time())
                time_delta = current_time - int(self.timestamp_settings)
                # Rotate through the five status lines, one step per 10 s.
                status_entry = int(time_delta / 10 % 5)
                status_message = {
                    "status": status_array[status_entry],
                    "uptime": current_time - int(self.timestamp_settings)
                }
                self._pub_message("misp_json_self", json.dumps(status_message))
                self._logger.debug(
                    "No message received for 10 seconds, sending ZMQ status message."
                )
            else:
                key, value = data
                # Strip the namespace prefix: "command" or "data:<topic>" remains.
                key = key.replace("{}:".format(self.namespace), "")
                if key == "command":
                    self._handle_command(value)
                elif key.startswith("data:"):
                    topic = key.split(":")[1]
                    self._logger.debug(
                        "Received data for topic '{}', sending to ZMQ.".format(
                            topic))
                    self._pub_message(topic, value)
                    self.message_count += 1
                    if topic == "misp_json":
                        self.publish_count += 1
                else:
                    self._logger.warning(
                        "Received invalid message '{}'.".format(key))
Exemple #11
0
class Actor(object):
    '''The actor class implements all the management and control functions over its components
    
    :param gModel: the JSON-based dictionary holding the model for the app this actor belongs to.
    :type gModel: dict
    :param gModelName: the name of the top-level model for the app
    :type gModelName: str
    :param aName: name of the actor. It is an index into the gModel that points to the part of the model specific to the actor
    :type aName: str
    :param sysArgv: list of arguments for the actor: -key1 value1 -key2 value2 ...
    :type sysArgv: list
         
    '''
    def __init__(self, gModel, gModelName, aName, sysArgv):
        '''Build the actor: resolve interfaces, parse parameters, create the
        zmq contexts (and the authenticator when security is enabled) and
        construct every component instance listed in the actor's model.

        :raises BuildError: if ``aName`` is not an actor of ``gModel`` or an
            instance refers to an undefined component/device type.
        '''
        self.logger = logging.getLogger(__name__)
        self.inst_ = self
        self.appName = gModel["name"]
        self.modelName = gModelName
        self.name = aName
        self.pid = os.getpid()
        self.uuid = None
        # Must run before actorID is computed: it sets self.globalHost.
        self.setupIfaces()
        # Assumption : pid is a 4 byte int
        # actorID = 4 bytes of the packed IPv4 address + 4-byte big-endian pid.
        self.actorID = ipaddress.IPv4Address(
            self.globalHost).packed + self.pid.to_bytes(4, 'big')
        self.suffix = ""
        if aName not in gModel["actors"]:
            raise BuildError('Actor "%s" unknown' % aName)
        self.model = gModel["actors"][
            aName]  # Fetch the relevant content from the model

        # Used by getParameterValueType() to recognise integer literals.
        self.INT_RE = re.compile(r"^[-]?\d+$")
        self.parseParams(sysArgv)

        # Use czmq's context
        czmq_ctx = Zsys.init()
        self.context = zmq.Context.shadow(czmq_ctx.value)
        Zsys.handler_reset()  # Reset previous signal handler

        # Context for app sockets
        self.appContext = zmq.Context()

        if Config.SECURITY:
            (self.public_key,
             self.private_key) = zmq.auth.load_certificate(const.appCertFile)
            hosts = ['127.0.0.1']
            # NOTE(review): yaml.load() is called without an explicit Loader and
            # any failure (including a missing `hosts` attribute) is silently
            # ignored, leaving only 127.0.0.1 allowed - confirm this is intended
            # and that the descriptor file is trusted.
            try:
                with open(const.appDescFile, 'r') as f:
                    content = yaml.load(f)
                    hosts += content.hosts
            except:
                pass

            # The authenticator guards the app-socket context only.
            self.auth = ThreadAuthenticator(self.appContext)
            self.auth.start()
            self.auth.allow(*hosts)
            self.auth.configure_curve(domain='*',
                                      location=zmq.auth.CURVE_ALLOW_ANY)
        else:
            (self.public_key, self.private_key) = (None, None)
            self.auth = None
            # Without security the app sockets share the czmq-backed context;
            # the zmq.Context() created above is replaced here.
            self.appContext = self.context

        # Logger configuration is best-effort: failures are logged, not raised.
        try:
            if os.path.isfile(const.logConfFile) and os.access(
                    const.logConfFile, os.R_OK):
                spdlog_setup.from_file(const.logConfFile)
        except Exception as e:
            self.logger.error("error while configuring componentLogger: %s" %
                              repr(e))

        messages = gModel[
            "messages"]  # Global message types (global on the network)
        self.messageNames = []
        for messageSpec in messages:
            self.messageNames.append(messageSpec["name"])

        locals_ = self.model[
            "locals"]  # Local message types (local to the host)
        self.localNames = []
        for messageSpec in locals_:
            self.localNames.append(messageSpec["type"])

        internals = self.model[
            "internals"]  # Internal message types (internal to the actor process)
        self.internalNames = []
        for messageSpec in internals:
            self.internalNames.append(messageSpec["type"])

        self.components = {}
        instSpecs = self.model["instances"]
        compSpecs = gModel["components"]
        ioSpecs = gModel["devices"]
        for instName in instSpecs:  # Create the component instances: the 'parts'
            instSpec = instSpecs[instName]
            instType = instSpec['type']
            # A type is looked up among components first, then among devices.
            if instType in compSpecs:
                typeSpec = compSpecs[instType]
                ioComp = False
            elif instType in ioSpecs:
                typeSpec = ioSpecs[instType]
                ioComp = True
            else:
                raise BuildError(
                    'Component type "%s" for instance "%s" is undefined' %
                    (instType, instName))
            instFormals = typeSpec['formals']
            instActuals = instSpec['actuals']
            instArgs = self.buildInstArgs(instName, instFormals, instActuals)

            # Check whether the component is C++ component
            # (decided by the presence of lib<type>.so relative to the current
            # working directory)
            ccComponentFile = 'lib' + instType.lower() + '.so'
            ccComp = os.path.isfile(ccComponentFile)
            # Construction errors are logged and swallowed; the failed instance
            # is then simply absent from self.components.
            try:
                if not ioComp:
                    if ccComp:
                        modObj = importlib.import_module('lib' +
                                                         instType.lower())
                        self.components[instName] = modObj.create_component_py(
                            self, self.model, typeSpec, instName, instType,
                            instArgs, self.appName, self.name)
                    else:
                        self.components[instName] = Part(
                            self, typeSpec, instName, instType, instArgs)
                else:
                    self.components[instName] = Peripheral(
                        self, typeSpec, instName, instType, instArgs)
            except Exception as e:
                traceback.print_exc()
                self.logger.error("Error while constructing part '%s.%s': %s" %
                                  (instType, instName, str(e)))

    def getParameterValueType(self, param, defaultType):
        ''' Infer the type of a parameter from its value unless a default type is provided. \
            In the latter case the parameter's value is converted to that type.
            
            :param param: a parameter value
            :type param: one of bool,int,float,str
            :param defaultType:
            :type defaultType: one of bool,int,float,str
            :return: a pair (value,type)
            :rtype: tuple
             
        '''
        paramValue, paramType = None, None
        if defaultType != None:
            if defaultType == str:
                paramValue, paramType = param, str
            elif defaultType == int:
                paramValue, paramType = int(param), int
            elif defaultType == float:
                paramValue, paramType = float(param), float
            elif defaultType == bool:
                paramType = bool
                paramValue = False if param == "False" else True if param == "True" else None
                paramValue, paramType = bool(param), float
        else:
            if param == 'True':
                paramValue, paramType = True, bool
            elif param == 'False':
                paramValue, paramType = True, bool
            elif self.INT_RE.match(param) is not None:
                paramValue, paramType = int(param), int
            else:
                try:
                    paramValue, paramType = float(param), float
                except:
                    paramValue, paramType = str(param), str
        return (paramValue, paramType)

    def parseParams(self, sysArgv):
        '''Parse actor arguments from the command line
        
        Compares the actual arguments to the formal arguments (from the model) and
        fills out the local parameter table accordingly. Generates a warning on 
        extra arguments and raises an exception on required but missing ones.
           
        '''
        self.params = {}
        formals = self.model["formals"]
        optList = []
        for formal in formals:
            key = formal["name"]
            default = None if "default" not in formal else formal["default"]
            self.params[key] = default
            optList.append("%s=" % key)
        try:
            opts, _args = getopt.getopt(sysArgv, '', optList)
        except:
            self.logger.info("Error parsing actor options %s" % str(sysArgv))
            return
        for opt in opts:
            optName2, optValue = opt
            optName = optName2[2:]  # Drop two leading dashes
            if optName in self.params:
                defaultType = None if self.params[optName] == None else type(
                    self.params[optName])
                paramValue, paramType = self.getParameterValueType(
                    optValue, defaultType)
                if self.params[optName] != None:
                    if paramType != type(self.params[optName]):
                        raise BuildError(
                            "Type of default value does not match type of argument %s"
                            % str((optName, optValue)))
                self.params[optName] = paramValue
            else:
                self.logger.info("Unknown argument %s - ignored" % optName)
        for param in self.params:
            if self.params[param] == None:
                raise BuildError("Required parameter %s missing" % param)

    def buildInstArgs(self, instName, formals, actuals):
        args = {}
        for formal in formals:
            argName = formal['name']
            argValue = None
            actual = next(
                (actual for actual in actuals if actual['name'] == argName),
                None)
            defaultValue = None
            if 'default' in formal:
                defaultValue = formal['default']
            if actual != None:
                assert (actual['name'] == argName)
                if 'param' in actual:
                    paramName = actual['param']
                    if paramName in self.params:
                        argValue = self.params[paramName]
                    else:
                        raise BuildError(
                            "Unspecified parameter %s referenced in %s" %
                            (paramName, instName))
                elif 'value' in actual:
                    argValue = actual['value']
                else:
                    raise BuildError("Actual parameter %s has no value" %
                                     argName)
            elif defaultValue != None:
                argValue = defaultValue
            else:
                raise BuildError("Argument %s in %s has no defined value" %
                                 (argName, instName))
            args[argName] = argValue
        return args

    def isLocalMessage(self, msgTypeName):
        '''Return True if the message type is local
        
        '''
        return msgTypeName in self.localNames

    def isInnerMessage(self, msgTypeName):
        '''Return True if the message type is internal
        
        '''
        return msgTypeName in self.internalNames

    def getLocalIface(self):
        '''Return the IP address of the host-local network interface (usually 127.0.0.1) 
        '''
        return self.localHost

    def getGlobalIface(self):
        '''Return the IP address of the global network interface
        '''
        return self.globalHost

    def getActorName(self):
        '''Return the name of this actor (as defined in the app model)
        '''
        return self.name

    def getAppName(self):
        '''Return the name of the app this actor belongs to
        '''
        return self.appName

    def getActorID(self):
        '''Returns an ID for this actor.
        
        The actor's id constructed from the host's IP address the actor's process id. 
        The id is unique for a given host and actor run.
        '''
        return self.actorID

    def setUUID(self, uuid):
        '''Sets the UUID for this actor.
        
        The UUID is dynamically generated (by the peer-to-peer network system)
        and is unique. 
        '''
        self.uuid = uuid

    def getUUID(self):
        '''Return the UUID for this actor. 
        '''
        return self.uuid

    def setupIfaces(self):
        '''Find the IP addresses of the (host-)local and network(-global) interfaces

        Stores the local IP in ``self.localHost`` and the first global IP and
        MAC address in ``self.globalHost`` / ``self.macAddress``.

        :raises AssertionError: if no global interface (IP and MAC) is
            reported. The check is explicit so that it also runs when Python
            is started with ``-O``.
        '''
        (globalIPs, globalMACs, _globalNames, localIP) = getNetworkInterfaces()
        if len(globalIPs) == 0 or len(globalMACs) == 0:
            self.logger.error("Error: no active network interface")
            raise AssertionError("no active network interface")
        self.localHost = localIP
        self.globalHost = globalIPs[0]
        self.macAddress = globalMACs[0]

    def setup(self):
        '''Second construction phase, run before the parts are activated.

        Starts the discovery and deployment clients and registers the actor
        with both, then creates one bound PAIR control socket per component
        and calls the component's own setup.
        '''
        log = self.logger
        log.info("setup")
        self.suffix = self.macAddress

        # Discovery service client
        self.disco = DiscoClient(self, self.suffix)
        self.disco.start()
        self.disco.registerApp()
        log.info("actor registered with disco")

        # Deployment service client
        self.deplc = DeplClient(self, self.suffix)
        self.deplc.start()
        ok = self.deplc.registerApp()
        outcome = "is" if ok else "is not"
        log.info("actor %s registered with depl" % (outcome))

        self.controls = {}
        self.controlMap = {}
        for instName, comp in self.components.items():
            ctrlSocket = self.context.socket(zmq.PAIR)
            ctrlSocket.bind('inproc://part_' + instName + '_control')
            self.controls[instName] = ctrlSocket
            # Keyed by socket identity so a socket can be mapped back to its part.
            self.controlMap[id(ctrlSocket)] = comp
            # Only Part instances receive the control socket.
            if isinstance(comp, Part):
                comp.setup(ctrlSocket)
            else:
                comp.setup()

    def registerEndpoint(self, bundle):
        '''
        Forward an endpoint registration to the discovery service client and
        apply every (partName, portName, host, port) entry it returns via
        updatePart().
        '''
        self.logger.info("registerEndpoint")
        updates = self.disco.registerEndpoint(bundle)
        for (partName, portName, host, port) in updates:
            self.updatePart(partName, portName, host, port)

    def registerDevice(self, bundle):
        '''Forward a device registration to the deployment service client.

        ``bundle`` is unpacked as (typeName, args); the request sent on is
        (appName, modelName, typeName, args). The client's reply is returned
        unchanged.
        '''
        deviceType, deviceArgs = bundle
        request = (self.appName, self.modelName, deviceType, deviceArgs)
        return self.deplc.registerDevice(request)

    def unregisterDevice(self, bundle):
        '''Forward a device unregistration to the deployment service client.

        ``bundle`` must hold exactly one item, the device type name; the
        request sent on is (appName, modelName, typeName). The client's reply
        is returned unchanged.
        '''
        (deviceType,) = bundle
        request = (self.appName, self.modelName, deviceType)
        return self.deplc.unregisterDevice(request)

    def activate(self):
        '''Call activate() on every component of this actor.
        '''
        self.logger.info("activate")
        for part in self.components.values():
            part.activate()

    def deactivate(self):
        '''Call deactivate() on every component of this actor.
        '''
        self.logger.info("deactivate")
        for part in self.components.values():
            part.deactivate()

    def recvChannelMessages(self, channel):
        '''Drain the channel without blocking and return the messages as a list.

        Messages are read with zmq.NOBLOCK until the socket raises zmq.Again;
        the list is in arrival order and may be empty.
        '''
        drained = []
        try:
            while True:
                drained.append(channel.recv(flags=zmq.NOBLOCK))
        except zmq.Again:
            # Nothing more is queued right now
            pass
        return drained

    def start(self):
        '''
        Run the actor's main polling loop; this call never returns.

        The deplo channel, the disco channel and every part control socket
        are registered with a fresh poller. On each wake-up one source is
        serviced, in this priority order: all queued discovery messages,
        otherwise all queued deployment messages, otherwise one event report
        from each ready part control socket.
        '''
        self.logger.info("starting")
        # Private channels to the discovery and deployment service clients
        self.discoChannel = self.disco.channel
        self.deplChannel = self.deplc.channel

        self.poller = zmq.Poller()
        self.poller.register(self.deplChannel, zmq.POLLIN)
        self.poller.register(self.discoChannel, zmq.POLLIN)
        for ctrlSocket in self.controls.values():
            self.poller.register(ctrlSocket, zmq.POLLIN)

        while True:
            ready = dict(self.poller.poll())

            if self.discoChannel in ready:
                # Discovery service updates take precedence
                for update in self.recvChannelMessages(self.discoChannel):
                    self.handleServiceUpdate(update)
                continue

            if self.deplChannel in ready:
                # Deployment service commands
                for command in self.recvChannelMessages(self.deplChannel):
                    self.handleDeplMessage(command)
                continue

            # Anything else that is ready should be a part control socket
            for sock in ready:
                if sock in self.controls.values():
                    part = self.controlMap[id(sock)]
                    report = sock.recv_pyobj()  # python object sent by the part
                    self.handleEventReport(part, report)

    def handleServiceUpdate(self, msgBytes):
        '''
        Decode a DiscoUpd message from the discovery service and apply it.

        Only 'portUpdate' messages are acted on; any other kind is ignored.
        The message must be addressed to this actor (host, actor name and a
        known component instance), which is checked with assertions, and the
        referenced part is then told about the new host/port.
        '''
        update = disco_capnp.DiscoUpd.from_bytes(msgBytes)
        if update.which() != 'portUpdate':
            return

        portUpdate = update.portUpdate
        target = portUpdate.client
        # The update has to be addressed to this actor ...
        assert target.actorHost == self.globalHost
        assert target.actorName == self.name
        # ... and to one of its parts
        instanceName = target.instanceName
        assert instanceName in self.components
        portName = target.portName

        endpoint = portUpdate.socket
        host = endpoint.host
        port = endpoint.port
        if portUpdate.scope == "local":
            assert host == self.localHost
        self.updatePart(instanceName, portName, host, port)

    def updatePart(self, instanceName, portName, host, port):
        '''
        Log the update and pass (portName, host, port) to the named
        component's handlePortUpdate().
        '''
        details = str((instanceName, portName, host, port))
        self.logger.info("updatePart %s" % details)
        self.components[instanceName].handlePortUpdate(portName, host, port)

    def handleDeplMessage(self, msgBytes):
        '''
        Decode a DeplCmd message from the deployment service and dispatch it.

        Resource-limit messages go to the matching handle*Limit() method,
        'reinstateCmd' to handleReinstate(), 'nicStateMsg' and 'peerInfoMsg'
        to the corresponding state-change handlers. Unknown kinds are logged
        as errors.
        '''
        command = deplo_capnp.DeplCmd.from_bytes(msgBytes)
        which = command.which()

        if which == 'resourceMsg':
            what = command.resourceMsg.which()
            limitHandlers = (
                ('resCPUX', self.handleCPULimit),
                ('resMemX', self.handleMemLimit),
                ('resSpcX', self.handleSpcLimit),
                ('resNetX', self.handleNetLimit),
            )
            for tag, handler in limitHandlers:
                if what == tag:
                    handler()
                    break
            else:
                self.logger.error("unknown resource msg from deplo: '%s'" %
                                  what)
            return

        if which == 'reinstateCmd':
            self.handleReinstate()
            return

        if which == 'nicStateMsg':
            nicState = str(command.nicStateMsg.nicState)
            self.handleNICStateChange(nicState)
            return

        if which == 'peerInfoMsg':
            peerMsg = command.peerInfoMsg
            peerState = str(peerMsg.peerState)
            peerUUID = peerMsg.uuid
            self.handlePeerStateChange(peerState, peerUUID)
            return

        self.logger.error("unknown msg from deplo: '%s'" % which)

    def handleReinstate(self):
        '''
        Handle a reinstate command: swap the poller registration from the old
        disco channel to the one available after disco.reconnect(), then call
        handleReinstate() on every component.
        '''
        self.logger.info('handleReinstate')
        # The old channel must leave the poller before the client reconnects
        self.poller.unregister(self.discoChannel)
        self.disco.reconnect()
        freshChannel = self.disco.channel
        self.discoChannel = freshChannel
        self.poller.register(freshChannel, zmq.POLLIN)
        for part in self.components.values():
            part.handleReinstate()

    def handleNICStateChange(self, state):
        '''
        Pass a NIC state change on to every component via its
        handleNICStateChange(state).
        '''
        self.logger.info("handleNICStateChange")
        for part in self.components.values():
            part.handleNICStateChange(state)

    def handlePeerStateChange(self, state, uuid):
        '''
        Pass a peer state change on to every component via its
        handlePeerStateChange(state, uuid).
        '''
        self.logger.info("handlePeerStateChange")
        for part in self.components.values():
            part.handlePeerStateChange(state, uuid)

    def handleCPULimit(self):
        '''
        CPU limit exceeded: call handleCPULimit() on each component in turn.
        '''
        self.logger.info("handleCPULimit")
        for part in self.components.values():
            part.handleCPULimit()

    def handleMemLimit(self):
        '''
        Memory limit exceeded: call handleMemLimit() on each component in turn.
        '''
        self.logger.info("handleMemLimit")
        for part in self.components.values():
            part.handleMemLimit()

    def handleSpcLimit(self):
        '''
        File space limit exceeded: call handleSpcLimit() on each component in
        turn.
        '''
        self.logger.info("handleSpcLimit")
        for part in self.components.values():
            part.handleSpcLimit()

    def handleNetLimit(self):
        '''
        Net usage limit exceeded: call handleNetLimit() on each component in
        turn.
        '''
        self.logger.info("handleNetLimit")
        for part in self.components.values():
            part.handleNetLimit()

    def handleEventReport(self, part, msg):
        '''Forward an event report from a part to the deplo service client.

        The report is sent as the tuple (part name, part type name, msg).
        '''
        report = (part.getName(), part.getTypeName(), msg)
        self.deplc.reportEvent(report)

    def terminate(self):
        '''Shut the actor down and exit the process.

        Every component is terminated first; after a one-second pause the
        deplo and disco clients are terminated and the authenticator (if any)
        is stopped. The method does not return: it ends with os._exit(0).
        '''
        self.logger.info("terminating")
        for part in self.components.values():
            part.terminate()
        time.sleep(1.0)
        self.deplc.terminate()
        self.disco.terminate()
        if self.auth:
            self.auth.stop()
        # The zmq context is deliberately not destroyed here (the original
        # had `self.context.destroy()` commented out).
        self.logger.info("terminated")
        os._exit(0)
Exemple #12
0
class RpcServer:
    """
    RPC server built on two ZeroMQ sockets.

    A REP socket answers remote calls to functions added with ``register``;
    a PUB socket carries ``publish`` data and a periodic heartbeat. When
    credentials are passed to ``start``, a ``ThreadAuthenticator`` enforces
    CURVE or PLAIN authentication on both sockets.
    """

    def __init__(self) -> None:
        """
        Create the ZeroMQ context and sockets; nothing is bound until
        ``start`` is called.
        """
        # Save functions dict: key is function name, value is function object
        self._functions: Dict[str, Callable] = {}

        # Zmq port related
        self._context: zmq.Context = zmq.Context()

        # Reply socket (Request–reply pattern)
        self._socket_rep: zmq.Socket = self._context.socket(zmq.REP)

        # Publish socket (Publish–subscribe pattern)
        self._socket_pub: zmq.Socket = self._context.socket(zmq.PUB)

        # Worker thread related
        self._active: bool = False  # RpcServer status
        self._thread: threading.Thread = None  # RpcServer thread
        self._lock: threading.Lock = threading.Lock()  # guards the PUB socket

        # Heartbeat related: time at which the next heartbeat is due
        self._heartbeat_at: int = None

        # Authenticator used to ensure data security
        self.__authenticator: ThreadAuthenticator = None

    def is_active(self) -> bool:
        """Return True between ``start`` and ``stop``."""
        return self._active

    def start(self,
              rep_address: str,
              pub_address: str,
              username: str = "",
              password: str = "",
              server_secretkey_path: str = "") -> None:
        """
        Start RpcServer

        Binds the REP socket to ``rep_address`` and the PUB socket to
        ``pub_address`` and starts the worker thread. Does nothing if the
        server is already active.

        If ``server_secretkey_path`` is given, both sockets become CURVE
        servers using that certificate; otherwise, if both ``username`` and
        ``password`` are given, both sockets become PLAIN servers accepting
        that single account. With neither, no authentication is set up.
        """
        if self._active:
            return

        # Start authenticator.
        # The context and sockets are the single-underscore attributes set in
        # __init__; double-underscore names would be mangled to
        # _RpcServer__context etc., which do not exist.
        if server_secretkey_path:
            self.__authenticator = ThreadAuthenticator(self._context)
            self.__authenticator.start()
            self.__authenticator.configure_curve(
                domain="*", location=zmq.auth.CURVE_ALLOW_ANY)

            publickey, secretkey = zmq.auth.load_certificate(
                server_secretkey_path)

            self._socket_pub.curve_secretkey = secretkey
            self._socket_pub.curve_publickey = publickey
            self._socket_pub.curve_server = True

            self._socket_rep.curve_secretkey = secretkey
            self._socket_rep.curve_publickey = publickey
            self._socket_rep.curve_server = True
        elif username and password:
            self.__authenticator = ThreadAuthenticator(self._context)
            self.__authenticator.start()
            self.__authenticator.configure_plain(
                domain="*", passwords={username: password})

            self._socket_pub.plain_server = True
            self._socket_rep.plain_server = True

        # Bind socket address
        self._socket_rep.bind(rep_address)
        self._socket_pub.bind(pub_address)

        # Start RpcServer status
        self._active = True

        # Init heartbeat publish timestamp. This must happen before the
        # worker thread starts, because run() calls check_heartbeat(), which
        # compares against this value.
        self._heartbeat_at = time() + HEARTBEAT_INTERVAL

        # Start RpcServer thread
        self._thread = threading.Thread(target=self.run)
        self._thread.start()

    def stop(self) -> None:
        """
        Stop RpcServer

        Only clears the active flag; the worker thread notices it on its next
        poll cycle and performs the actual unbinding. Use ``join`` to wait.
        """
        if not self._active:
            return

        # Stop RpcServer status
        self._active = False

    def join(self) -> None:
        """Wait for the worker thread (if any) to exit, then forget it."""
        if self._thread and self._thread.is_alive():
            self._thread.join()
        self._thread = None

    def run(self) -> None:
        """
        Run RpcServer functions

        Worker-thread body: serves requests until the active flag is cleared,
        then unbinds both sockets and stops the authenticator.
        """
        while self._active:
            # Poll response socket for 1 second
            n: int = self._socket_rep.poll(1000)
            self.check_heartbeat()

            if not n:
                continue

            # Receive request data from Reply socket
            req = self._socket_rep.recv_pyobj()

            # Get function name and parameters
            name, args, kwargs = req

            # Try to get and execute callable function object; capture
            # exception information if it fails
            try:
                func: Callable = self._functions[name]
                r: Any = func(*args, **kwargs)
                rep: list = [True, r]
            except Exception:  # noqa
                rep: list = [False, traceback.format_exc()]

            # send callable response by Reply socket
            self._socket_rep.send_pyobj(rep)

        # Unbind socket address
        self._socket_pub.unbind(self._socket_pub.LAST_ENDPOINT)
        self._socket_rep.unbind(self._socket_rep.LAST_ENDPOINT)
        if self.__authenticator:
            self.__authenticator.stop()

    def publish(self, topic: str, data: Any) -> None:
        """
        Publish ``[topic, data]`` on the PUB socket.

        The send is serialised with a lock because both the worker thread
        (heartbeat) and callers may publish.
        """
        with self._lock:
            self._socket_pub.send_pyobj([topic, data])

    def register(self, func: Callable) -> None:
        """
        Register function

        The function becomes callable remotely under ``func.__name__``; a
        later registration with the same name replaces the earlier one.
        """
        self._functions[func.__name__] = func

    def check_heartbeat(self) -> None:
        """
        Publish a heartbeat if one is due and schedule the next one.
        """
        now: float = time()
        if now >= self._heartbeat_at:
            # Publish heartbeat
            self.publish(HEARTBEAT_TOPIC, now)

            # Update timestamp of next publish
            self._heartbeat_at = now + HEARTBEAT_INTERVAL
Exemple #13
0
  def run(self):
    """Main loop of the client thread.

    Sets up a zmq context with a CURVE ThreadAuthenticator, then loops until
    ``self.stop`` is set. Roughly once per second it refreshes the rate limits
    and opens a DEALER socket to every enabled peer that is due for a
    (re)connect. On every pass it visits the connected peers in random order
    and, per peer: drains the high and medium priority queues, sends at most
    one throttled file request and one low priority message, sends a rollcall
    request when due, processes all pending replies, and drops the peer on a
    socket error or timeout. On exit all sockets are closed and the
    authenticator and context are shut down.

    NOTE(review): ``Queue.Queue`` and shuffling the result of ``dict.keys()``
    suggest this was written for Python 2 -- confirm before running on 3.
    """
    self.set_status("Client Startup")
    self.set_status("Creating zmq Contexts",1)
    clientctx = zmq.Context() 
    self.set_status("Starting zmq ThreadedAuthenticator",1)
    #clientauth = zmq.auth.ThreadedAuthenticator(clientctx)
    clientauth = ThreadAuthenticator(clientctx)
    clientauth.start()
    
    # Certificate directories are <store>/<local uuid>/public and .../private
    with taco.globals.settings_lock:
      publicdir  = os.path.normpath(os.path.abspath(taco.globals.settings["TacoNET Certificates Store"] + "/"  + taco.globals.settings["Local UUID"] + "/public/"))
      privatedir = os.path.normpath(os.path.abspath(taco.globals.settings["TacoNET Certificates Store"] + "/"  + taco.globals.settings["Local UUID"] + "/private/"))

    self.set_status("Configuring Curve to use publickey dir:" + publicdir)
    clientauth.configure_curve(domain='*', location=publicdir)
    
    poller = zmq.Poller()
    while not self.stop.is_set():
      # Wait up to 0.1s; self.sleep is set further down whenever work was
      # done, so a busy loop iterates again without waiting.
      #logging.debug("PRE")
      result = self.sleep.wait(0.1)
      #logging.debug(result)
      self.sleep.clear()
      if self.stop.is_set(): break

      # --- connect block: runs when more than 1s has passed since the last run
      if abs(time.time() - self.connect_block_time) > 1:
        with taco.globals.settings_lock: self.max_upload_rate   = taco.globals.settings["Upload Limit"] * taco.constants.KB
        with taco.globals.settings_lock: self.max_download_rate = taco.globals.settings["Download Limit"] * taco.constants.KB
        # NOTE(review): raises ZeroDivisionError if "Download Limit" is 0.
        self.chunk_request_rate = float(taco.constants.FILESYSTEM_CHUNK_SIZE) / float(self.max_download_rate)
        #logging.debug(str((self.max_download_rate,taco.constants.FILESYSTEM_CHUNK_SIZE,self.chunk_request_rate)))
        self.connect_block_time = time.time() 
        with taco.globals.settings_lock:
          for peer_uuid in taco.globals.settings["Peers"].keys():
            if taco.globals.settings["Peers"][peer_uuid]["enabled"]:
              #init some defaults
              if not peer_uuid in self.client_reconnect_mod: self.client_reconnect_mod[peer_uuid] = taco.constants.CLIENT_RECONNECT_MIN
              if not peer_uuid in self.client_connect_time:  self.client_connect_time[peer_uuid]  = time.time() + self.client_reconnect_mod[peer_uuid]
              if not peer_uuid in self.client_timeout:       self.client_timeout[peer_uuid]       = time.time() + taco.constants.ROLLCALL_TIMEOUT

              # Only connect once the peer's (re)connect time has been reached
              if time.time() >= self.client_connect_time[peer_uuid]:
                if peer_uuid not in self.clients.keys():
                  self.set_status("Starting Client for: " + peer_uuid)
                  # NOTE(review): bare except -- any failure of the lookup is
                  # reported as a DNS problem and the peer is skipped this round.
                  try:
                    ip_of_client = socket.gethostbyname(taco.globals.settings["Peers"][peer_uuid]["hostname"])
                  except:
                    self.set_status("Starting of client failed due to bad dns lookup:" + peer_uuid)
                    continue
                  # DEALER socket with LINGER 0, authenticated via CURVE: our
                  # client key pair plus the peer's configured server key
                  self.clients[peer_uuid] = clientctx.socket(zmq.DEALER)
                  self.clients[peer_uuid].setsockopt(zmq.LINGER, 0)
                  client_public, client_secret = zmq.auth.load_certificate(os.path.normpath(os.path.abspath(privatedir + "/" + taco.constants.KEY_GENERATION_PREFIX +"-client.key_secret")))
                  self.clients[peer_uuid].curve_secretkey = client_secret
                  self.clients[peer_uuid].curve_publickey = client_public
                  self.clients[peer_uuid].curve_serverkey = str(taco.globals.settings["Peers"][peer_uuid]["serverkey"])
                  self.clients[peer_uuid].connect("tcp://" + ip_of_client + ":" + str(taco.globals.settings["Peers"][peer_uuid]["port"]))
                  # First rollcall is due immediately
                  self.next_rollcall[peer_uuid] = time.time()

                  # Fresh per-peer output queues, one per priority class
                  with taco.globals.high_priority_output_queue_lock:   taco.globals.high_priority_output_queue[peer_uuid]   = Queue.Queue()
                  with taco.globals.medium_priority_output_queue_lock: taco.globals.medium_priority_output_queue[peer_uuid] = Queue.Queue()
                  with taco.globals.low_priority_output_queue_lock:    taco.globals.low_priority_output_queue[peer_uuid]    = Queue.Queue()
                  with taco.globals.file_request_output_queue_lock:    taco.globals.file_request_output_queue[peer_uuid]    = Queue.Queue()

                  poller.register(self.clients[peer_uuid],zmq.POLLIN)

      if len(self.clients.keys()) == 0: continue

      # Visit peers in random order on each pass.
      # NOTE(review): random.shuffle needs a list; this relies on keys()
      # returning one (Python 2). The copy also makes the `del
      # self.clients[...]` in the cleanup block safe during this loop.
      peer_keys = self.clients.keys()
      random.shuffle(peer_keys)
      for peer_uuid in peer_keys:
        #self.set_status("Socket Write Possible:" + peer_uuid)

        # Every outgoing message is sent as two frames: an empty frame
        # followed by the payload.

        #high priority queue processing (drained completely, not rate limited)
        with taco.globals.high_priority_output_queue_lock:
          while not taco.globals.high_priority_output_queue[peer_uuid].empty():
            self.set_status("high priority output q not empty:" + peer_uuid)
            data = taco.globals.high_priority_output_queue[peer_uuid].get()
            self.clients[peer_uuid].send_multipart(['',data])
            self.sleep.set()
            with taco.globals.upload_limiter_lock: taco.globals.upload_limiter.add(len(data))

        #medium priority queue processing (drained completely, not rate limited)
        with taco.globals.medium_priority_output_queue_lock:
          while not taco.globals.medium_priority_output_queue[peer_uuid].empty():
            self.set_status("medium priority output q not empty:" + peer_uuid)
            data = taco.globals.medium_priority_output_queue[peer_uuid].get()
            self.clients[peer_uuid].send_multipart(['',data])
            self.sleep.set()
            with taco.globals.upload_limiter_lock: taco.globals.upload_limiter.add(len(data))

        #filereq q, aka the download throttle 
        # At most one file request per peer per pass, and only once
        # file_request_time (shared across all peers) has been reached.
        if time.time() >= self.file_request_time:
          self.file_request_time = time.time() 
          with taco.globals.file_request_output_queue_lock:
            if not taco.globals.file_request_output_queue[peer_uuid].empty():
              with taco.globals.download_limiter_lock: download_rate = taco.globals.download_limiter.get_rate()

              # Push the next allowed request time out in proportion to how
              # much of the download limit is currently in use.
              bw_percent = download_rate / self.max_download_rate
              wait_time = self.chunk_request_rate * bw_percent
              #self.set_status(str((download_rate,self.max_download_rate,self.chunk_request_rate,bw_percent,wait_time)))
              if wait_time > 0.01: self.file_request_time += wait_time

              if download_rate < self.max_download_rate:
                self.set_status("filereq output q not empty+free bw:" + peer_uuid)
                data = taco.globals.file_request_output_queue[peer_uuid].get()
                self.clients[peer_uuid].send_multipart(['',data])
                self.sleep.set()
                with taco.globals.upload_limiter_lock: taco.globals.upload_limiter.add(len(data))

        #low priority queue processing
        # At most one message per pass, and only while under the upload limit
        with taco.globals.low_priority_output_queue_lock:
          if not taco.globals.low_priority_output_queue[peer_uuid].empty():
            with taco.globals.upload_limiter_lock: upload_rate = taco.globals.upload_limiter.get_rate()
            if upload_rate < self.max_upload_rate:
              self.set_status("low priority output q not empty+free bw:" + peer_uuid)
              data = taco.globals.low_priority_output_queue[peer_uuid].get()
              self.clients[peer_uuid].send_multipart(['',data])
              self.sleep.set()
              with taco.globals.upload_limiter_lock: taco.globals.upload_limiter.add(len(data))

        #rollcall special case
        # Sent directly (bypassing the queues); the next one is scheduled a
        # random ROLLCALL_MIN..ROLLCALL_MAX seconds ahead.
        if self.next_rollcall[peer_uuid] < time.time():
          #self.set_status("Requesting Rollcall from: " + peer_uuid)
          data = taco.commands.Request_Rollcall()
          self.clients[peer_uuid].send_multipart(['',data])
          with taco.globals.upload_limiter_lock: taco.globals.upload_limiter.add(len(data))
          self.next_rollcall[peer_uuid] = time.time() + random.randint(taco.constants.ROLLCALL_MIN,taco.constants.ROLLCALL_MAX)
          self.sleep.set()
          #continue

        #RECEIVE BLOCK
        # Non-blocking poll; keep reading while this peer's socket is readable.
        # Each reply is expected to be exactly two frames.
        socks = dict(poller.poll(0))
        while self.clients[peer_uuid] in socks and socks[self.clients[peer_uuid]] == zmq.POLLIN:
          #self.set_status("Socket Read Possible")
          sink,data = self.clients[peer_uuid].recv_multipart()
          with taco.globals.download_limiter_lock: taco.globals.download_limiter.add(len(data))
          self.set_client_last_reply(peer_uuid)
          # A non-empty result from Process_Reply is queued as a follow-up
          # request at medium priority.
          self.next_request = taco.commands.Process_Reply(peer_uuid,data)
          if self.next_request != "":
            with taco.globals.medium_priority_output_queue_lock:
              taco.globals.medium_priority_output_queue[peer_uuid].put(self.next_request)
          self.sleep.set()
          socks = dict(poller.poll(0))

        #cleanup block
        # NOTE(review): sockets are registered for POLLIN only, so the POLLERR
        # test may never match -- confirm against the pyzmq poller behaviour.
        # NOTE(review): presumably set_client_last_reply() refreshes
        # client_timeout; that is not visible here.
        self.error_msg = []
        if self.clients[peer_uuid] in socks and socks[self.clients[peer_uuid]] == zmq.POLLERR: self.error_msg.append("got a socket error")
        if abs(self.client_timeout[peer_uuid] - time.time()) > taco.constants.ROLLCALL_TIMEOUT: self.error_msg.append("havn't seen communications")

        if len(self.error_msg) > 0:
          # Tear down the socket and all per-peer state, then schedule a
          # reconnect with a delay that grows by CLIENT_RECONNECT_MOD per
          # failure, capped at CLIENT_RECONNECT_MAX.
          self.set_status("Stopping client: " + peer_uuid + " -- " + " and ".join(self.error_msg),2)
          poller.unregister(self.clients[peer_uuid])
          self.clients[peer_uuid].close(0)
          del self.clients[peer_uuid]          
          del self.client_timeout[peer_uuid]
          with taco.globals.high_priority_output_queue_lock:    del taco.globals.high_priority_output_queue[peer_uuid]
          with taco.globals.medium_priority_output_queue_lock:  del taco.globals.medium_priority_output_queue[peer_uuid]
          with taco.globals.low_priority_output_queue_lock:     del taco.globals.low_priority_output_queue[peer_uuid]
          with taco.globals.file_request_output_queue_lock:     del taco.globals.file_request_output_queue[peer_uuid]
          self.client_reconnect_mod[peer_uuid] = min(self.client_reconnect_mod[peer_uuid] + taco.constants.CLIENT_RECONNECT_MOD,taco.constants.CLIENT_RECONNECT_MAX)
          self.client_connect_time[peer_uuid] = time.time() + self.client_reconnect_mod[peer_uuid]
          

        
    # Shutdown: close every remaining socket, then stop the authenticator and
    # terminate the context.
    self.set_status("Terminating Clients")
    for peer_uuid in self.clients.keys():
      self.clients[peer_uuid].close(0)
    self.set_status("Stopping zmq ThreadedAuthenticator")
    clientauth.stop() 
    clientctx.term()
    self.set_status("Clients Exit")    
Exemple #14
0
class CombaZMQAdapter(threading.Thread, CombaBase):
    """Thread that serves a ZeroMQ REP socket and forwards text commands to a
    CombaController.

    Each request is a space-separated string: the first word names a
    controller method, the remaining words are passed as string arguments.
    The socket is configured as a PLAIN server; authentication is handled by
    a ThreadAuthenticator according to ``self.securitylevel``.

    NOTE(review): this is Python 2 code (``print`` / ``exec`` statements,
    ``string.lowercase``).
    NOTE(review): ``securitylevel``, ``lqs_socket`` and ``lqs_recorder_socket``
    are read but never assigned in this class -- presumably provided by
    ``loadConfig()`` / ``CombaBase``; confirm.
    """
    
    def __init__(self, port):
        """Store the port, create the authenticator, load the configuration
        and start the thread immediately.

        :param port: TCP port to bind on (converted to a string)
        """

        self.port = str(port)
        threading.Thread.__init__ (self)
        self.shutdown_event = Event()
        # NOTE(review): this constructs a new Context and then calls
        # instance() on it; presumably `zmq.Context.instance()` was intended.
        self.context = zmq.Context().instance()
        self.authserver = ThreadAuthenticator(self.context)
        self.loadConfig()
        self.start()

    #------------------------------------------------------------------------------------------#
    def run(self):
        """
        Thread body: start authentication, bind the REP socket and process
        incoming commands until the shutdown event is set.
        """

        self.startAuthserver()
        self.data = ''
        self.socket = self.context.socket(zmq.REP)
        self.socket.plain_server = True
        self.socket.bind("tcp://*:"+self.port)
        self.shutdown_event.clear()
        self.controller = CombaController(self, self.lqs_socket, self.lqs_recorder_socket)
        self.controller.messenger.setMailAddresses(self.get('frommail'), self.get('adminmail'))
        # can_send is True from the moment a request arrives until send() has
        # delivered the reply for it
        self.can_send = False
        # Process tasks forever
        while not self.shutdown_event.is_set():
            # Blocks until a request arrives, so a shutdown request is only
            # noticed after the next message.
            self.data = self.socket.recv()
            self.can_send = True
            data = self.data.split(' ')
            command = str(data.pop(0)) 
            # Build a Python argument list of quoted strings, e.g. "('a','b')"
            params = "()" if len(data) < 1 else  "('" + "','".join(data) + "')" 
                     
            # SECURITY NOTE(review): the text received from the socket is
            # spliced into Python source and executed with `exec`. A client
            # can therefore run arbitrary code (e.g. via a quote character in
            # an argument). Consider getattr(self.controller, command)(*data)
            # with an allow-list instead.
            try: 
                exec"a=self.controller." + command + params  
            
            except SyntaxError:                
                self.controller.message('Warning: Syntax Error')

            except AttributeError:
                print "Warning: Method " + command + " does not exist"
                self.controller.message('Warning: Method ' + command + ' does not exist')
            except TypeError:
                print "Warning: Wrong number of params"
                self.controller.message('Warning: Wrong number of params')
            # NOTE(review): bare except -- every other error is reported as
            # "Unknown Error" and the loop continues.
            except:
                print "Warning: Unknown Error"
                self.controller.message('Warning: Unknown Error')

        return

    #------------------------------------------------------------------------------------------#
    def halt(self):
        """
        Stop the server: drop the controller, set the shutdown event and
        unbind the socket. Errors during these steps are ignored. Does nothing
        if the shutdown event is already set.
        """
        if self.shutdown_event.is_set():
            return
        try:
            del self.controller
        except:
            pass
        self.shutdown_event.set()
        # `result` is assigned but not used afterwards
        result = 'failed'
        try:
            result = self.socket.unbind("tcp://*:"+self.port)
        except:
            pass
        #self.socket.close()

    #------------------------------------------------------------------------------------------#
    def reload(self):
        """
        Stop, reload the configuration and start again.

        Does nothing if the shutdown event is already set.

        NOTE(review): ``run()`` is called directly here, so the new serving
        loop executes in the caller's thread rather than in a new thread.
        """
        if self.shutdown_event.is_set():
            return
        self.loadConfig()
        self.halt()
        time.sleep(3)
        self.run()

    #------------------------------------------------------------------------------------------#
    def send(self,message):
        """
        Send a message to the client.

        Only one reply is sent per received request; further calls are
        silently ignored until the next request arrives.

        :param message: string
        """
        if self.can_send:
            self.socket.send(message)
            self.can_send = False

    #------------------------------------------------------------------------------------------#
    def startAuthserver(self):
        """
        Start the zmq authentication server.

        With ``securitylevel`` > 0 the authenticator is started and configured
        for PLAIN username/password authentication. With ``securitylevel`` > 1
        the addresses from CombaWhitelist are additionally allowed.
        """
        # stop auth server if running
        # NOTE(review): the same ThreadAuthenticator object is started again
        # below after stop(); confirm that pyzmq supports restarting it.
        if self.authserver.is_alive():
            self.authserver.stop()
        if self.securitylevel > 0:
            # Start the authentication server.

            self.authserver.start()

            # At security level 2 also require password and user name
            # (the code in this branch adds the whitelisted addresses;
            # failures while loading the whitelist are ignored)
            if self.securitylevel > 1:
                try:

                    addresses = CombaWhitelist().getList()
                    for address in addresses:
                        self.authserver.allow(address)

                except:
                    pass

            # Instruct authenticator to handle PLAIN requests
            self.authserver.configure_plain(domain='*', passwords=self.getAccounts())

    #------------------------------------------------------------------------------------------#
    def getAccounts(self):
        """
        Get the accounts used for PLAIN authentication.

        Takes the logins from CombaUser and adds one internal account, which
        is read from the redis key 'internAccess' (stored as "user:password")
        or generated and stored there if missing.

        :return: the logins from CombaUser plus the internal account
                 (presumably a dict mapping user name to password -- it is
                 indexed by user name below)
        """
        accounts = CombaUser().getLogins()
        db = redis.Redis()

        internaccount = db.get('internAccess')
        if not internaccount:
            # SECURITY NOTE(review): credentials are generated with the
            # `random` module, which is not meant for secrets, and stored in
            # redis in plain text.
            user = ''.join(random.sample(string.lowercase,10))
            password = ''.join(random.sample(string.lowercase+string.uppercase+string.digits,22))
            db.set('internAccess', user + ':' + password)
            intern = [user, password]
        else:
            intern =  internaccount.split(':')

        accounts[intern[0]] = intern[1]

        return accounts
Exemple #15
0
class Facilities(furnace_runtime.FurnaceRuntime):
    """
    Main Facilities class.
    """
    def __init__(self, debug, u2, kp, qt, it, ip, fac):
        """
        Constructor, ZMQ connections to the APP partition and backend are built here.

        Starts a CURVE authenticator thread, connects DEALER, REQ and SUB
        sockets to the proxy and binds ROUTER and PUB ipc sockets for the
        frontend. Exits the process via ``sys.exit()`` if an ipc bind fails.
        :param debug: Passed through to the base class constructor and kept as ``self.debug``.
        :param u2: The UUID for the corresponding app instance.
        :param kp: Key file paths; the code reads ``kp["be_pub"]`` (backend public
                   certificate) and ``kp["app_key"]`` (this app's certificate).
        :param qt: Custom quotas dict; ``qt["_furnace_max_disk"]`` is read here.
        :param ip: Connect to this proxy IP.
        :param it: Connect to this proxy TCP port (the SUB socket uses ``it + 1``).
        :param fac: The path to the FAC partition's ZMQ socket.
        :raises Exception: if protobuf is running with the pure-python implementation.
        """

        super(Facilities, self).__init__(debug=debug)
        self.debug = debug
        self.u2 = u2
        self.fac_path = fac
        self.it = it
        self.ip = ip
        self.already_shutdown = False

        self.tprint(DEBUG, "starting Facilities constructor")
        self.tprint(DEBUG, f"running as UUID={u2}")

        # Refuse to run with the pure-python protobuf backend.
        api = api_implementation._default_implementation_type
        if api == "python":
            raise Exception("not running in cpp accelerated mode")
        self.tprint(DEBUG, f"running in api mode: {api}")

        # Reusable protobuf messages: fmsg_* for frontend traffic, pmsg_* for proxy traffic.
        self.fmsg_in = facilities_pb2.FacMessage()
        self.fmsg_out = facilities_pb2.FacMessage()
        self.pmsg_in = facilities_pb2.FacMessage()
        self.pmsg_out = facilities_pb2.FacMessage()

        # Key/value storage bookkeeping: key -> numeric io id, plus usage counters.
        self.storage = {}
        self.io_id = 0
        self.disk_usage = 0
        self.max_disk_usage = int(qt["_furnace_max_disk"])
        self.tprint(DEBUG, f"MAX_DISK_USAGE {self.max_disk_usage}")

        self.context = zmq.Context(io_threads=4)
        self.poller = zmq.Poller()

        # connections with the proxy
        self.tprint(DEBUG, f"loading crypto...")
        self.auth = ThreadAuthenticator(self.context)  # crypto bootstrap
        self.auth.start()
        self.auth.configure_curve(domain="*",
                                  location=zmq.auth.CURVE_ALLOW_ANY)
        publisher_public, _ = zmq.auth.load_certificate(kp["be_pub"])
        public, secret = zmq.auth.load_certificate(kp["app_key"])

        # NOTE(review): assert is stripped under -O; this check then disappears.
        assert self.auth.is_alive()
        self.tprint(DEBUG, f"loading crypto... DONE")

        TCP_PXY_RTR = f"tcp://{str(ip)}:{int(it)+0}"
        TCP_PXY_PUB = f"tcp://{str(ip)}:{int(it)+1}"

        # connections to the proxy
        # NOTE(review): the DEALER and REQ sockets connect to `s`, while
        # TCP_PXY_RTR is only used in the log message below.
        s = f"tcp://{self.ip}:{self.it}"
        self.dealer_pxy = self.context.socket(zmq.DEALER)  # async directed
        self.didentity = f"{u2}-d"  # this must match the FE
        self.dealer_pxy.identity = self.didentity.encode()
        self.dealer_pxy.curve_publickey = public
        self.dealer_pxy.curve_secretkey = secret
        self.dealer_pxy.curve_serverkey = publisher_public
        self.tprint(INFO, f"FAC--PXY: Connecting as Dealer to {TCP_PXY_RTR}")
        self.dealer_pxy.connect(s)

        self.req_pxy = self.context.socket(zmq.REQ)  # sync
        self.ridentity = f"{u2}-r"
        self.req_pxy.identity = self.ridentity.encode()
        self.req_pxy.curve_publickey = public
        self.req_pxy.curve_secretkey = secret
        self.req_pxy.curve_serverkey = publisher_public
        self.req_pxy.connect(s)

        # NOTE(review): this is a SUB socket subscribed to every topic, although
        # the log message below says "Publisher".
        self.sub_pxy = self.context.socket(zmq.SUB)  # async broadcast
        self.sub_pxy.curve_publickey = public
        self.sub_pxy.curve_secretkey = secret
        self.sub_pxy.curve_serverkey = publisher_public
        self.sub_pxy.setsockopt(zmq.SUBSCRIBE, b"")
        self.tprint(INFO,
                    f"FAC--PXY: Connecting as Publisher to {TCP_PXY_PUB}")
        self.sub_pxy.connect(TCP_PXY_PUB)

        # Only the DEALER and SUB proxy sockets are polled; REQ is not registered.
        self.poller.register(self.dealer_pxy, zmq.POLLIN)
        self.poller.register(self.sub_pxy, zmq.POLLIN)

        # connections with the frontend
        s = f"ipc://{self.fac_path}{os.sep}fac_rtr"
        self.rtr_fac = self.context.socket(zmq.ROUTER)  # fe sync/async msgs
        self.tprint(INFO, f"APP--FAC: Binding as Router to {s}")
        try:
            self.rtr_fac.bind(s)
        except zmq.error.ZMQError:
            self.tprint(ERROR, f"rtr_fac {s} is already in use!")
            # NOTE(review): sys.exit() without an argument exits with status 0.
            sys.exit()

        s = f"ipc://{self.fac_path}{os.sep}fac_pub"
        self.pub_fac = self.context.socket(zmq.PUB)  # bcast to fe
        self.tprint(INFO, f"APP--FAC: Binding as Publisher to {s}")
        try:
            self.pub_fac.bind(s)
        except zmq.error.ZMQError:
            self.tprint(ERROR, f"pub_fac {s} is already in use!")
            sys.exit()

        self.poller.register(self.rtr_fac, zmq.POLLIN)
        self.tprint(DEBUG, "leaving Facilities constructor")

    def loop(self):
        """
        Main event loop.  Polls (with timeout) on ZMQ sockets waiting for data from the APP partition and backend.

        Each iteration writes progress characters to stdout ("v" for inbound
        proxy traffic, "^" for inbound frontend traffic, "." when no socket
        was ready) and increments ``self.tick``.
        :returns: Nothing; the loop has no exit condition of its own.
        """

        while True:
            socks = dict(self.poller.poll(TIMEOUT_FAC))

            # inbound async from BE (sync should not arrive here)
            if self.dealer_pxy in socks:
                sys.stdout.write("v")
                raw_msg = self.dealer_pxy.recv()
                self.msgin += 1
                self.be2fe_process_protobuf(raw_msg, ASYNC)

            # inbound broadcast from BE
            if self.sub_pxy in socks:
                sys.stdout.write("v")
                raw_msg = self.sub_pxy.recv()
                self.msgin += 1
                self.pmsg_in.ParseFromString(raw_msg)
                self.tprint(DEBUG, "BROADCAST FROM PXY\n%s" % (self.pmsg_in, ))
                # Only the first be_msg entry of a broadcast is relayed.
                index = 0
                submsg_list = getattr(self.pmsg_in, "be_msg")
                m = submsg_list[index]

                self.pmsg_out.Clear()
                msgnum = self.pmsg_out.__getattribute__("BE_MSG")
                self.pmsg_out.type.append(msgnum)
                submsg = getattr(self.pmsg_out, "be_msg").add()
                submsg.status = m.status
                submsg.value = m.value
                raw_str = self.pmsg_out.SerializeToString()
                self.tprint(DEBUG, f"to FE {len(raw_str)}B {self.pmsg_out}")
                self.pub_fac.send(raw_str)
                self.msgout += 1

            # inbound from FE
            if self.rtr_fac in socks:
                sys.stdout.write("^")
                pkt = self.rtr_fac.recv_multipart()
                self.msgin += 1
                if len(pkt) == 3:  # sync
                    ident, empty, raw_msg = pkt
                    raw_out = self.fe2be_process_protobuf(raw_msg, SYNC)
                    # Frames from recv_multipart() and SerializeToString() are
                    # already bytes; bytes has no .encode(), so they are sent
                    # back unchanged.
                    self.rtr_fac.send_multipart([ident, b"", raw_out])
                    self.msgout += 1
                elif len(pkt) == 2:  # async
                    ident, raw_msg = pkt
                    self.fe2be_process_protobuf(raw_msg, ASYNC)

            if not socks:
                sys.stdout.write(".")
            sys.stdout.flush()

            self.tick += 1

    def be2fe_process_protobuf(self, raw_msg, sync):
        """
        Processes messages sent by the backend.

        Every ``BE_MSG`` entry is re-wrapped in ``self.fmsg_out`` and sent to the
        frontend through the router socket, addressed to ``self.didentity``.
        Entries of other types are removed from the incoming message and dropped.
        :param raw_msg: Serialized FacMessage received from the backend.
        :param sync: Not used by this method.
        :returns: Nothing (modifies class's protobufs).
        """
        self.pmsg_in.ParseFromString(raw_msg)
        self.tprint(DEBUG, "FROM BE\n%s" % (self.pmsg_in, ))

        for t in self.pmsg_in.type:
            msgtype = facilities_pb2.FacMessage.Type.Name(
                t)  # msgtype = BE_MSG
            msglist = getattr(self.pmsg_in, msgtype.lower())
            m = msglist.pop(0)  # the particular message
            if msgtype == "BE_MSG":
                self.fmsg_out.Clear()
                msgnum = self.fmsg_out.__getattribute__("BE_MSG")
                self.fmsg_out.type.append(msgnum)
                submsg = getattr(self.fmsg_out, "be_msg").add()
                submsg.status = m.status
                submsg.value = m.value
                raw_str = self.fmsg_out.SerializeToString()
                # f-string: the placeholders were previously logged literally.
                self.tprint(
                    DEBUG,
                    f"to FE {self.didentity} {len(raw_str)}B {self.fmsg_out}")
                self.rtr_fac.send_multipart([self.didentity.encode(), raw_str])
                self.msgout += 1

        return

    # Split off between pass-through to BE or fac-provided IO.
    def fe2be_process_protobuf(self, raw_msg, sync):
        """
        Processes messages sent by the app.  Some messages are processed locally (IO).
        Others are forwarded to the backend.

        ``BE_MSG`` entries go to ``fe2be_process_msg``, ``GET`` to ``process_get``
        and ``SET`` to ``process_set``; other types are dropped. The handlers
        fill ``self.fmsg_out``, which is serialized as the reply.
        :param raw_msg: Serialized FacMessage received from the frontend.
        :param sync: SYNC or ASYNC, passed on to ``fe2be_process_msg``.
        :returns: Raw serialized protobuf.
        """
        self.fmsg_in.ParseFromString(raw_msg)
        self.tprint(DEBUG, "FROM FE\n%s" % (self.fmsg_in, ))

        self.fmsg_out.Clear()
        for t in self.fmsg_in.type:
            msgtype = facilities_pb2.FacMessage.Type.Name(
                t)  # msgtype = BE_MSG
            msglist = getattr(self.fmsg_in, msgtype.lower())
            m = msglist.pop(0)  # the particular message
            if msgtype == "BE_MSG":  # Pass-through to BE
                self.fe2be_process_msg(m, sync)
            elif msgtype == "GET":  # IO
                self.process_get(m)
            elif msgtype == "SET":  # IO
                self.process_set(m)

        # f-string: the placeholder was previously logged literally.
        self.tprint(DEBUG, f"sending {self.fmsg_out}")
        raw_out = self.fmsg_out.SerializeToString()
        return raw_out

    def fe2be_process_msg(self, m, sync):
        """
        Processes messages sent by the app.  All these are destined to the backend.

        The message is re-wrapped in ``self.pmsg_out``. For SYNC it is sent on
        the REQ socket and the reply is awaited (blocking) and copied into
        ``self.fmsg_out`` as a ``BE_MSG_RET``. For ASYNC it is sent on the DEALER
        socket without waiting. Any other ``sync`` value sends nothing.
        :param m: The BE_MSG sub-message taken from the frontend's request.
        :param sync: SYNC or ASYNC.
        :returns: Nothing (modifies class's protobufs).
        """
        self.pmsg_out.Clear()
        msgnum = self.pmsg_out.__getattribute__("BE_MSG")
        self.pmsg_out.type.append(msgnum)
        submsg = getattr(self.pmsg_out, "be_msg").add()
        submsg.status = m.status
        submsg.value = m.value
        raw_str = self.pmsg_out.SerializeToString()
        self.tprint(DEBUG,
                    "sending to BE %dB %s" % (len(raw_str), self.pmsg_out))

        if sync == SYNC:
            self.req_pxy.send(raw_str)
            self.msgout += 1
            raw_reply = self.req_pxy.recv()
            self.pmsg_in.ParseFromString(raw_reply)
            self.msgin += 1
            # Report the size of the reply, not of the request that was sent.
            self.tprint(DEBUG,
                        "recv from BE %dB %s" % (len(raw_reply), self.pmsg_in))

            # Only the first be_msg_ret entry of the reply is used.
            index = 0
            submsg_list = getattr(self.pmsg_in, "be_msg_ret")
            status = submsg_list[index].status
            val = submsg_list[index].value

            self.fmsg_out.Clear()
            msgnum = self.fmsg_out.__getattribute__("BE_MSG_RET")
            self.fmsg_out.type.append(msgnum)
            submsg = getattr(self.fmsg_out, "be_msg_ret").add()
            submsg.status = status
            submsg.value = val
            return

        elif sync == ASYNC:
            self.dealer_pxy.send(raw_str)
            self.msgout += 1
            return

    def process_get(self, m):
        """
        The app is asking for previously stored data.

        Appends a ``GET_RET`` entry to ``self.fmsg_out``. The result is
        ``VMI_FAILURE`` with an empty value when the key is invalid, unknown, or
        its backing file cannot be read.
        :param m: The protobuf message sent by the app.
        :returns: Nothing (modifies class's protobufs).
        """
        key = str(m.key)
        msgtype = "GET_RET"
        msgnum = self.fmsg_out.__getattribute__(msgtype)
        self.fmsg_out.type.append(msgnum)
        submsg = getattr(self.fmsg_out, msgtype.lower()).add()
        submsg.key = key
        if not self.get_name_ok(key):
            submsg.value = ""
            submsg.result = VMI_FAILURE
            return
        try:
            io_id = self.storage[key]
            with open(FAC_IO_DIR + str(io_id), "rb") as fh:
                submsg.value = fh.read().decode()
            submsg.result = VMI_SUCCESS
        except (KeyError, OSError):
            # self.storage is a dict, so an unknown key raises KeyError (not
            # IndexError); OSError covers a missing or unreadable backing file.
            submsg.value = ""
            submsg.result = VMI_FAILURE
        return

    def process_set(self, m):
        """
        The app has requested we store this data.

        Appends a ``SET_RET`` entry to ``self.fmsg_out``. The value is written to
        a new file under ``FAC_IO_DIR``; an existing file for the same key is
        removed first. The request fails when the key is invalid, when the
        configured disk quota would be exceeded, or when file handling raises.
        :param m: The protobuf message sent by the app.
        :returns: Nothing (modifies class's protobufs).
        """
        key = str(m.key)
        value = str(m.value)
        # All accounting is done in bytes, matching what ends up on disk.
        payload = value.encode()
        msgtype = "SET_RET"
        msgnum = self.fmsg_out.__getattribute__(msgtype)
        self.fmsg_out.type.append(msgnum)
        submsg = getattr(self.fmsg_out, msgtype.lower()).add()
        submsg.key = key
        if not self.get_name_ok(key):
            self.tprint(ERROR, "key not allowed")
            submsg.value = ""
            submsg.result = VMI_FAILURE
            return

        # Enforce the per-instance quota configured in the constructor.
        self.tprint(
            DEBUG,
            f"cur={self.disk_usage}, req={len(payload)}, max={self.max_disk_usage}",
        )
        if self.disk_usage + len(payload) > self.max_disk_usage:
            self.tprint(ERROR, "disk usage exceeded")
            submsg.result = VMI_FAILURE
            return

        try:
            self.tprint(DEBUG, f"checking for existing key {key}")
            # overwrite a key
            if key in self.storage:
                self.tprint(DEBUG, "found!")
                existing_path = FAC_IO_DIR + str(self.storage[key])
                s = os.path.getsize(existing_path)
                self.tprint(DEBUG, f"deleting existing file of size={s}")
                os.remove(existing_path)
                # Adjust the bookkeeping only once the file is really gone.
                self.disk_usage -= s
                del self.storage[key]
            else:
                self.tprint(DEBUG, "not found")

            io_id = self.io_id
            self.io_id += 1
            self.tprint(DEBUG, f"will assign key {io_id}")
            self.tprint(DEBUG, f"writing file of size {len(payload)}")
            with open(FAC_IO_DIR + str(io_id), "wb") as fh:
                fh.write(payload)
            # Record the new entry only after the write succeeded, so a failed
            # write cannot leave a dangling key or inflated usage behind.
            self.disk_usage += len(payload)
            self.tprint(DEBUG, f"disk usage at {self.disk_usage}")
            self.storage[key] = io_id
            submsg.result = VMI_SUCCESS
        except Exception:
            # Best effort: report the failure to the app instead of crashing
            # the event loop.
            submsg.result = VMI_FAILURE
        return

    def get_name_ok(self, n):
        """
        Check a storage key against the naming convention.

        :param n: The key (1 to 16 characters, each a word character or a hyphen).
        :returns: True if in accordance with naming convention, False if otherwise.
        """
        # fullmatch also rejects a trailing newline, which "$" would let through.
        return 0 < len(n) < 17 and re.fullmatch(r"[\w-]+", n) is not None

    def check_encoding(self, val):
        """
        Verify that a value is a string.

        :param val: The value to check.
        :returns: Nothing.
        :raises: TypeError if val is not a string.
        """
        if isinstance(val, str):
            return
        raise TypeError("Error: value is not of type str")

    def shutdown(self):
        """
        Exit cleanly.

        Removes the stored IO files, stops the authenticator, closes the
        sockets, destroys the ZMQ context and finally calls the base class
        shutdown. A second call only logs an error.
        :returns: Nothing.
        """
        if self.already_shutdown:
            self.tprint(ERROR, "already shutdown?")
            return

        self.already_shutdown = True
        for key, io_id in self.storage.items():
            self.tprint(INFO, "removing io: %s -> %s" % (key, io_id))
            try:
                os.remove(FAC_IO_DIR + str(io_id))
            except OSError as err:
                # A missing file must not abort the rest of the shutdown: the
                # flag above is already set, so it could never be retried.
                self.tprint(ERROR, f"could not remove io {io_id}: {err}")

        self.auth.stop()

        self.poller.unregister(self.dealer_pxy)
        self.poller.unregister(self.sub_pxy)
        self.dealer_pxy.close()
        self.sub_pxy.close()
        # The REQ socket is never registered with the poller; just close it.
        self.req_pxy.close()

        self.poller.unregister(self.rtr_fac)
        self.rtr_fac.close()
        self.pub_fac.close()

        self.context.destroy()
        super(Facilities, self).shutdown()
Exemple #16
0
def main():
    """
    Runs SEND either in transmitter or receiver mode

    Parses the command line, starts the CURVE authenticator on the manager's
    ZMQ context, launches the transfer thread and waits for it to finish
    before stopping the authenticator and destroying the context.
    :raises ValueError: if neither --transmit nor --receive was given.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "-t",
        "--transmit",
        action="store_true",
        help="Flag indicating that user will be transmitting files"
    )
    parser.add_argument(
        "-r",
        "--receive",
        action="store_true",
        help="Flag indicating that user will be receiving files"
    )
    parser.add_argument(
        "--location",
        help="Location of files to send/receive. Can be a specific file if tx."
    )
    parser.add_argument(
        "--ip",
        help="IP Address to form connection with"
    )
    parser.add_argument(
        "--port",
        nargs='?',
        const=6000,
        default=6000,
        type=int,
        help="Port to form connection with (only needed if using non-default)"
    )
    parser.add_argument(
        "--public_key",
        nargs='?',
        help="Public Key of transmitter in plain-text (only needed if receiver)"
    )
    args = parser.parse_args()

    # Security Authentication Thread
    _generate_security_keys()
    authenticator = ThreadAuthenticator(manager.ctx)
    authenticator.start()
    # --ip has no default, so leave it out of the whitelist when it is absent.
    whitelist = [address for address in ("127.0.0.1", args.ip) if address]
    authenticator.allow(*whitelist)
    authenticator.configure_curve(domain="*", location=PUBKEYS)

    # Stays None when no transfer thread could be started, so the cleanup
    # below never hits an unbound name and masks the real error.
    thread = None
    try:
        if args.transmit:
            thread = manager.publish_folder(
                args.port,
                args.location
            )

        elif args.receive:
            thread = manager.subscribe_folder(
                args.ip,
                args.port,
                args.location,
                args.public_key
            )
        else:
            raise ValueError("User did not specify transmit/receive")

    except KeyboardInterrupt:
        pass

    finally:
        # Keep things rolling until the transfer is done or the thread dies from
        # timing out. join() blocks without spinning the CPU.
        if thread is not None:
            thread.join()
        # Clean up and close everything out
        authenticator.stop()
        # Use destroy versus term: https://github.com/zeromq/pyzmq/issues/991
        manager.ctx.destroy()
Exemple #17
0
class MultiNodeAgent(BEMOSSAgent):
    def __init__(self, *args, **kwargs):
        """
        Initialise the node bookkeeping, set up the ZMQ server side, register
        the periodic jobs and subscriptions, and finally call ``self.run()``.
        """
        super(MultiNodeAgent, self).__init__(*args, **kwargs)
        self.multinode_status = {}
        self.getMultinodeData()
        self.agent_id = 'multinodeagent'
        self.is_parent = False
        # A date far in the past stands in for "never synced".
        self.last_sync_with_parent = datetime(1991, 1, 1)
        self.parent_node = None
        # Both lists start empty and are filled as nodes are discovered to be
        # online/offline.
        self.recently_online_node_list = []
        self.recently_offline_node_list = []
        self.setup()

        # Periods are presumably seconds -- TODO confirm against runPeriodically.
        periodic_jobs = (
            (self.send_heartbeat, 20, {}),
            (self.check_health, 60, {'start_immediately': False}),
            (self.sync_all_with_parent, 600, {}),
        )
        for job, period, options in periodic_jobs:
            self.runPeriodically(job, period, **options)

        subscriptions = (
            ('relay_message', self.relayDirectMessage),
            ('update_multinode_data', self.updateMultinodeData),
        )
        for topic, handler in subscriptions:
            self.subscribe(topic, handler)

        self.runContinuously(self.pollClients)
        self.run()

    def getMultinodeData(self):
        """
        Load the multinode configuration and derive the lookup structures.

        Sets ``multinode_data``, ``nodelist_dict``, ``node_name_list``,
        ``address_list``, ``server_key_list``, ``node_name`` and ``node_index``,
        and creates a status record for every node not yet present in
        ``multinode_status``.
        :raises ValueError: if ``this_node`` does not name one of the known nodes.
        """
        self.multinode_data = db_helper.get_multinode_data()
        known_nodes = self.multinode_data['known_nodes']

        self.nodelist_dict = dict((entry['name'], entry) for entry in known_nodes)
        self.node_name_list = [entry['name'] for entry in known_nodes]
        self.address_list = [entry['address'] for entry in known_nodes]
        self.server_key_list = [entry['server_key'] for entry in known_nodes]
        self.node_name = self.multinode_data['this_node']

        # Position of this node in the known-nodes list (first match wins).
        own_positions = [
            position for position, entry in enumerate(known_nodes)
            if entry['name'] == self.node_name
        ]
        if not own_positions:
            raise ValueError(
                '"this_node:" entry on the multinode_data json file is invalid'
            )
        self.node_index = own_positions[0]

        # Existing records are kept: this method may be called again later.
        for name in self.node_name_list:
            if name in self.multinode_status:
                continue
            self.multinode_status[name] = {
                'health': -10,  # initialized; never online/offline
                'last_sync_time': datetime(1991, 1, 1),
                'last_online_time': None,
                'last_offline_time': None,
                'last_scanned_time': None,
            }

    def setup(self):
        """
        Prepare CURVE security and bind this node's PUB server socket.

        Exits the process with status 1 when the key directories are missing.
        Starts a ``ThreadAuthenticator`` on the shared ZMQ context, binds the
        server socket to this node's configured address and then calls
        ``configureClient()``.
        """
        print("Setup")

        base_dir = settings.PROJECT_DIR + "/"
        self.secret_keys_dir = os.path.abspath(
            os.path.join(base_dir, 'private_keys'))
        self.public_keys_dir = os.path.abspath(
            os.path.join(base_dir, 'public_keys'))

        keys_present = (os.path.exists(self.public_keys_dir)
                        and os.path.exists(self.secret_keys_dir))
        if not keys_present:
            logging.critical(
                "Certificates are missing - run generate_certificates.py script first"
            )
            sys.exit(1)

        self.ctx = zmq.Context.instance()
        # Start an authenticator for this context.
        self.auth = ThreadAuthenticator(self.ctx)
        self.auth.start()
        self.configure_authenticator()

        publisher = self.ctx.socket(zmq.PUB)

        own_entry = self.multinode_data['known_nodes'][self.node_index]
        secret_file = os.path.join(self.secret_keys_dir,
                                   own_entry['server_secret_key'])
        public_key, secret_key = zmq.auth.load_certificate(secret_file)
        publisher.curve_secretkey = secret_key
        publisher.curve_publickey = public_key
        publisher.curve_server = True  # must come before bind
        publisher.bind(
            self.multinode_data['known_nodes'][self.node_index]['address'])
        self.server = publisher
        self.configureClient()

    def configure_authenticator(self):
        """Configure the authenticator to use the certificates in the public keys directory."""
        authenticator = self.auth
        # Called without addresses, exactly as before.
        authenticator.allow()
        # Tell authenticator to use the certificate in a directory
        authenticator.configure_curve(domain='*',
                                      location=self.public_keys_dir)

    def disperseMessage(self, sender, topic, message):
        """Send the message once per known node other than this one, under '<node>/republish/<topic>'."""
        for target in self.node_name_list:
            if target != self.node_name:
                republish_topic = target + '/republish/' + topic
                self.server.send(jsonify(sender, republish_topic, message))

    def republishToParent(self, sender, topic, message):
        """
        Send the message to every node whose health marks it as parent.

        Does nothing when this node is itself the parent, because the message
        is then already published.
        """
        if self.is_parent:
            return
        for target in self.node_name_list:
            # health == 2 is the parent node
            if self.multinode_status[target]['health'] != 2:
                continue
            republish_topic = target + '/republish/' + topic
            self.server.send(jsonify(sender, republish_topic, message))

    def sync_node_with_parent(self, node_name):
        """
        As parent, send a sync-with-parent message to one node.

        Records the current time as ``last_sync_with_parent``. The message body
        is empty; sending a database dump is currently disabled.
        """
        if not self.is_parent:
            return
        print("Syncing " + node_name)
        self.last_sync_with_parent = datetime.now()
        stamp = self.last_sync_with_parent.strftime('%B %d, %Y, %H:%M:%S')
        sync_topic = (node_name + '/sync-with-parent/' + stamp + '/' +
                      self.node_name)
        self.server.send(jsonify(self.node_name, sync_topic, ""))

    def sync_all_with_parent(self, dbcon):
        """
        As parent, send a sync-with-parent message to every other known node.

        Records the current time as ``last_sync_with_parent``. The message body
        is empty; sending a database dump is currently disabled.
        :param dbcon: Not used by this method.
        """
        if not self.is_parent:
            return
        self.last_sync_with_parent = datetime.now()
        stamp = self.last_sync_with_parent.strftime('%B %d, %Y, %H:%M:%S')
        print("Syncing all nodes")
        for target in self.node_name_list:
            if target != self.node_name:
                sync_topic = (target + '/sync-with-parent/' + stamp + '/' +
                              self.node_name)
                self.server.send(jsonify(self.node_name, sync_topic, ""))

    def send_heartbeat(self, dbcon):
        """
        Send a heartbeat carrying this node's name, parent flag and last sync time.

        :param dbcon: Not used by this method.
        """
        print("Sending heartbeat")

        last_sync = self.last_sync_with_parent.strftime('%B %d, %Y, %H:%M:%S')
        heartbeat_topic = '/'.join(
            ['heartbeat', self.node_name, str(self.is_parent), last_sync])
        self.server.send(jsonify(self.node_name, heartbeat_topic, ""))

    def extract_ip(self, addr):
        """Return the first dotted-quad IPv4-looking substring found in ``addr``."""
        dotted_quad = r'([0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3})'
        found = re.search(dotted_quad, addr)
        # NOTE(review): found is None when addr holds no such substring, which
        # raises AttributeError here.
        return found.group(1)

    def getNodeId(self, node_name):
        """
        Return the index of the named node in the known-nodes list.

        :raises ValueError: if no known node has that name.
        """
        for position, entry in enumerate(self.multinode_data['known_nodes']):
            if entry['name'] == node_name:
                return position
        raise ValueError('the node name: ' + node_name +
                         ' is not found in multinode data')

    def getNodeName(self, node_id):
        """Return the name of the known node at index ``node_id``."""
        entry = self.multinode_data['known_nodes'][node_id]
        return entry['name']

    def handle_offline_nodes(self, dbcon, node_name_list):
        """
        As parent, take over the agents assigned to nodes that went offline.

        For every listed node an offline event is registered, its assigned
        agents are looked up, a temporary 'start' command for this node is
        queued per agent, and the agent's current node is updated and
        committed. The queued commands are published to the network agent.
        Does nothing when this node is not the parent.
        """
        if not self.is_parent:
            return

        # start all the agents belonging to that node on this node
        commands = []
        for offline_name in node_name_list:
            offline_id = self.getNodeId(offline_name)
            # put the offline event into cassandra events log table, and also create notification
            self.EventRegister(dbcon,
                               'node-offline',
                               reason='communication-error',
                               source=offline_name)
            # get a list of agents that were supposedly running in that offline node
            dbcon.execute(
                "SELECT agent_id FROM " + node_devices_table +
                " WHERE assigned_node_id=%s", (offline_id, ))
            if not dbcon.rowcount:
                continue

            for row in dbcon.fetchall():
                commands.append({
                    STATUS_CHANGE.AGENT_ID: row[0],
                    STATUS_CHANGE.NODE: str(self.node_index),
                    STATUS_CHANGE.AGENT_STATUS: 'start',
                    STATUS_CHANGE.NODE_ASSIGNMENT_TYPE:
                        ZONE_ASSIGNMENT_TYPES.TEMPORARY,
                })
                dbcon.execute(
                    "UPDATE " + node_devices_table +
                    " SET current_node_id=(%s), date_move=(%s)"
                    " WHERE agent_id=(%s)",
                    (self.node_index, datetime.now(pytz.UTC), row[0]))
                dbcon.commit()

        print("moving agents from offline node to parent: " +
              str(node_name_list))
        print(commands)
        if commands:
            self.bemoss_publish(target='networkagent',
                                topic='status_change',
                                message=commands)

    def handle_online_nodes(self, dbcon, node_name_list):
        """
        As parent, hand agents back to nodes that came online again.

        For every listed node except this one an online event is registered and
        its assigned agents are looked up. Per agent a 'stop' command for this
        node and a 'start' command for the returning node are queued (both with
        permanent assignment), and the agent's current node is updated and
        committed. The queued commands are published to the network agent.
        Does nothing when this node is not the parent.
        """
        if not self.is_parent:
            return

        # start all the agents belonging to that nodes back on them
        commands = []
        for online_name in node_name_list:
            online_id = self.getNodeId(online_name)
            if online_id == self.node_index:
                continue  # don't handle self-online
            self.EventRegister(dbcon,
                               'node-online',
                               reason='communication-restored',
                               source=online_name)

            # get a list of agents that were supposed to be running in that online node
            dbcon.execute(
                "SELECT agent_id FROM " + node_devices_table +
                " WHERE assigned_node_id=%s", (online_id, ))
            if not dbcon.rowcount:
                continue

            for row in dbcon.fetchall():
                stop_here = {
                    STATUS_CHANGE.AGENT_ID: row[0],
                    STATUS_CHANGE.NODE_ASSIGNMENT_TYPE:
                        ZONE_ASSIGNMENT_TYPES.PERMANENT,
                    STATUS_CHANGE.NODE: str(self.node_index),
                    STATUS_CHANGE.AGENT_STATUS: 'stop',  # stop in this node
                }
                start_there = dict(stop_here)  # independent copy
                start_there[STATUS_CHANGE.NODE] = str(online_id)
                start_there[STATUS_CHANGE.AGENT_STATUS] = 'start'  # start in the target node
                commands.extend([stop_here, start_there])
                # immediately update the multinode device assignment table
                dbcon.execute(
                    "UPDATE " + node_devices_table +
                    " SET current_node_id=(%s), date_move=(%s)"
                    " WHERE agent_id=(%s)",
                    (online_id, datetime.now(pytz.UTC), row[0]))
                dbcon.commit()

        print("Moving agents back to the online node: " +
              str(node_name_list))
        print(commands)

        if commands:
            self.bemoss_publish(target='networkagent',
                                topic='status_change',
                                message=commands)

    def updateParent(self, dbcon, parent_node_name):
        """
        Persist the parent node's IP address and refresh the multinode data.

        The IP is taken from the parent's configured address. When the parent-IP
        file is missing or holds a different value, the file is rewritten, the
        given connection (if any) is closed, a new connection is opened and
        ``updateMultinodeData`` is called. Otherwise nothing happens.
        """
        parent_ip = self.extract_ip(
            self.nodelist_dict[parent_node_name]['address'])
        ip_file = settings.MULTINODE_PARENT_IP_FILE

        if os.path.isfile(ip_file):
            with open(ip_file, 'r') as f:
                outdated = f.read() != parent_ip
        else:
            # no parent file yet
            outdated = True
        if not outdated:
            return

        with open(ip_file, 'w') as f:
            f.write(parent_ip)
        if dbcon:
            dbcon.close()  # close old connection
        # start new connection using new parent_ip
        # NOTE(review): the new connection is only bound to the local name and
        # is neither returned nor used here -- confirm the caller does not keep
        # using the connection that was just closed.
        dbcon = db_helper.db_connection()
        self.updateMultinodeData(sender=self.name,
                                 topic='update_parent',
                                 message="")

    def check_health(self, dbcon):
        """Run one health-check round over all nodes in ``self.multinode_status``.

        Steps, in order:

        1. Reset the health of every online node (health > 0) to 0.
        2. Sleep 30 seconds, then classify each node by the health value found
           afterwards (presumably heartbeats handled elsewhere restore it to
           1 or 2 in the meantime -- verify against the caller).
        3. If no node reports itself as parent (health == 2), elect one and
           call ``updateParent``; otherwise call ``updateParent`` for the
           existing parent.
        4. Print the health of every known node.
        5. If this node is the parent, synchronise the node_info table with
           ``self.multinode_data['known_nodes']``, delete rows of nodes that
           are no longer known, and hand the recently online/offline node
           lists to their handlers.

        Health values used here: 2 = parent, 1 = online child, 0 = reset and
        awaiting a heartbeat, -1 = offline, -10 = initial value of a node that
        has never been online.

        Args:
            dbcon: Database connection; may be ``None``, in which case a new
                one is opened before the node_info table is updated.

        Raises:
            AssertionError: If no node is online after the waiting period.
        """

        for node_name, node in self.multinode_status.items():
            if node['health'] > 0:  # Reset all online nodes to 0. If they are really online, they should change it
                #  back to 1 or 2 (parent) within 30 seconds through the heartbeat.
                node['health'] = 0

        time.sleep(30)
        parent_node_name = None  # no parent seen yet
        online_node_exists = False
        for node_name, node in self.multinode_status.items():
            node['last_scanned_time'] = datetime.now()
            if node['health'] == 0:
                # Still 0 after the wait: no heartbeat arrived, mark it offline.
                node['health'] = -1
                node['last_offline_time'] = datetime.now()
                self.recently_offline_node_list += [node_name]
            elif node['health'] == -1:  # offline for a long time already
                pass
            elif node[
                    'health'] == -10:  # The node was initialized to -10 and never came online. Treat it as recently going
                # offline for this iteration so that the agents that were supposed to be running there can be migrated.
                node['health'] = -1
                self.recently_offline_node_list += [node_name]
            elif node['health'] == 2:  # some parent node is present
                parent_node_name = node_name
            if node['health'] > 0:
                online_node_exists = True  # At least one node (itself) should be online; if not, something is wrong

        assert online_node_exists, "At least one node (current node) must be online"

        if not parent_node_name:  # parent node doesn't exist
            # Find a suitable node to elect as parent. The node with the latest update from the previous parent wins.
            # If there is a tie, the node coming earlier in self.node_name_list wins.

            online_node_last_sync = dict(
            )  # only the online nodes, and their last-sync-times
            for node_name, node in self.multinode_status.items(
            ):  # copy only the online nodes
                if node['health'] > 0:
                    online_node_last_sync[node_name] = node['last_sync_time']

            latest_node = max(online_node_last_sync,
                              key=online_node_last_sync.get)
            latest_sync_date = online_node_last_sync[latest_node]

            for node_name in self.node_name_list:
                if self.multinode_status[node_name][
                        'health'] <= 0:  # dead nodes can't be parents
                    continue
                if self.multinode_status[node_name][
                        'last_sync_time'] == latest_sync_date:  # this is the first node with the latest update from parent
                    # eligible parent found
                    # NOTE(review): updateParent may close the connection passed in; dbcon is used again
                    # further down -- confirm the handle is still usable at that point.
                    self.updateParent(dbcon, node_name)

                    if node_name == self.node_name:  # I am the node, so I get to become the parent
                        self.is_parent = True
                        print "I am the boss now, " + self.node_name
                        break
                    else:  # I am not the first node with the latest update; somebody else is
                        self.is_parent = False
                        break
        else:  # parent node exists
            self.updateParent(dbcon, parent_node_name)

        # Report the health of every known node.
        for node in self.multinode_data['known_nodes']:
            print node['name'] + ': ' + str(
                self.multinode_status[node['name']]['health'])

        if self.is_parent:
            # If this is a parent node, update the node_info table.
            if dbcon is None:  # if no database connection exists, make one
                dbcon = db_helper.db_connection()

            tbl_node_info = settings.DATABASES['default']['TABLE_node_info']
            dbcon.execute('select node_id from ' + tbl_node_info)
            # Start with every node_id in the table; ids of known nodes are removed below, the rest get deleted.
            to_be_deleted_node_ids = dbcon.fetchall()
            for index, node in enumerate(self.multinode_data['known_nodes']):
                if (index, ) in to_be_deleted_node_ids:
                    to_be_deleted_node_ids.remove(
                        (index, ))  # this node is still known, keep its row
                # The select is only issued so that dbcon.rowcount (checked below) tells whether the row exists;
                # `result` itself is not used.
                result = dbcon.execute(
                    'select * from ' + tbl_node_info + ' where node_id=%s',
                    (index, ))
                node_type = 'parent' if self.multinode_status[
                    node['name']]['health'] == 2 else "child"
                node_status = "ONLINE" if self.multinode_status[
                    node['name']]['health'] > 0 else "OFFLINE"
                ip_address = self.extract_ip(node['address'])
                # NOTE(review): the last_scanned_time column is filled from 'last_online_time', not from the
                # 'last_scanned_time' value set above -- confirm this is intended.
                last_scanned_time = self.multinode_status[
                    node['name']]['last_online_time']
                last_offline_time = self.multinode_status[
                    node['name']]['last_offline_time']
                last_sync_time = self.multinode_status[
                    node['name']]['last_sync_time']

                var_list = "(node_id,node_name,node_type,node_status,ip_address,last_scanned_time,last_offline_time,last_sync_time)"
                value_placeholder_list = "(%s,%s,%s,%s,%s,%s,%s,%s)"
                actual_values_list = (index, node['name'], node_type,
                                      node_status, ip_address,
                                      last_scanned_time, last_offline_time,
                                      last_sync_time)

                if dbcon.rowcount == 0:
                    # No row for this node yet: insert one.
                    dbcon.execute(
                        "insert into " + tbl_node_info + " " + var_list +
                        " VALUES" + value_placeholder_list, actual_values_list)
                else:
                    # Row exists: overwrite all columns.
                    dbcon.execute(
                        "update " + tbl_node_info + " SET " + var_list +
                        " = " + value_placeholder_list + " where node_id = %s",
                        actual_values_list + (index, ))
            dbcon.commit()

            # Each `id` is a row returned by fetchall() and is passed directly as the query parameters.
            # NOTE(review): the name shadows the builtin id().
            for id in to_be_deleted_node_ids:
                dbcon.execute(
                    'delete from accounts_userprofile_nodes where nodeinfo_id=%s',
                    id)  # delete entries in user-profile for the old node
                dbcon.commit()
                dbcon.execute('delete from ' + tbl_node_info +
                              ' where node_id=%s', id)  # delete the old nodes
                dbcon.commit()

            if self.recently_online_node_list:  # Online nodes are handled first because the same node can be on both
                # recently_online_node_list and recently_offline_node_list when it goes offline shortly after
                # coming online.
                self.handle_online_nodes(dbcon, self.recently_online_node_list)
                self.recently_online_node_list = []  # reset after handling
            if self.recently_offline_node_list:
                self.handle_offline_nodes(dbcon,
                                          self.recently_offline_node_list)
                self.recently_offline_node_list = []  # reset after handling

    def connect_client(self, node):
        """Point the shared client socket at one node and connect to it.

        ``node`` is a mapping providing ``'server_key'`` (certificate file
        name inside ``self.public_keys_dir``) and ``'address'`` (the endpoint
        handed to ``connect``).
        """
        cert_path = os.path.join(self.public_keys_dir, node['server_key'])
        # load_certificate yields a (public, secret) pair; a CURVE client only
        # needs the server's public key.
        server_public, _unused_secret = zmq.auth.load_certificate(cert_path)

        sub_socket = self.client
        sub_socket.curve_serverkey = server_public
        # Subscribe to heartbeats and to messages addressed to this node.
        for prefix in ('heartbeat/', self.node_name):
            sub_socket.setsockopt(zmq.SUBSCRIBE, prefix)
        sub_socket.connect(node['address'])

    def disconnect_client(self, node):
        """Disconnect the shared client socket from ``node['address']``."""
        endpoint = node['address']
        self.client.disconnect(endpoint)

    def configureClient(self):
        """Create the CURVE-enabled SUB socket and connect it to all known nodes.

        The socket is created from ``self.ctx``, given this node's client key
        pair (loaded from the ``'client_secret_key'`` file of this node's entry
        in ``self.multinode_data['known_nodes']``), stored in ``self.client``
        and then connected to every known node through ``connect_client``.
        """
        print("Starting to receive Heart-beat")
        sub_socket = self.ctx.socket(zmq.SUB)

        # Two certificates are involved in a CURVE connection: the client's
        # own key pair (set here) and the server's public key (set per node
        # in connect_client).
        known_nodes = self.multinode_data['known_nodes']
        own_entry = known_nodes[self.node_index]
        secret_path = os.path.join(self.secret_keys_dir,
                                   own_entry['client_secret_key'])
        public_key, secret_key = zmq.auth.load_certificate(secret_path)
        sub_socket.curve_secretkey = secret_key
        sub_socket.curve_publickey = public_key

        self.client = sub_socket

        for peer in known_nodes:
            self.connect_client(peer)

    def pollClients(self, dbcon):
        if self.client.poll(1000):
            sender, topic, msg = dejsonify(self.client.recv())
            topic_list = topic.split('/')
            if topic_list[0] == 'heartbeat':
                node_name = sender
                is_parent = topic_list[2]
                last_sync_with_parent = topic_list[3]
                if self.multinode_status[node_name][
                        'health'] < 0:  #the node health was <0 , means offline
                    print node_name + " is back online"
                    self.recently_online_node_list += [node_name]
                    self.sync_node_with_parent(node_name)

                if is_parent.lower() in ['false', '0']:
                    self.multinode_status[node_name]['health'] = 1
                elif is_parent.lower() in ['true', '1']:
                    self.multinode_status[node_name]['health'] = 2
                    self.parent_node = node_name
                else:
                    raise ValueError(
                        'Invalid is_parent string in heart-beat message')

                self.multinode_status[node_name][
                    'last_online_time'] = datetime.now()
                self.multinode_status[node_name][
                    'last_sync_time'] = datetime.strptime(
                        last_sync_with_parent, '%B %d, %Y, %H:%M:%S')

            if topic_list[0] == self.node_name:
                if topic_list[1] == 'sync-with-parent':
                    pass
                    # print topic
                    # self.last_sync_with_parent = datetime.strptime(topic_list[2], '%B %d, %Y, %H:%M:%S')
                    # content = base64.b64decode(msg['database_dump'])
                    # newpath = 'bemossdb.sql'
                    # with open(newpath, 'w') as f:
                    #     f.write(content)
                    # try:
                    #     os.system(
                    #         'psql -c "SELECT pg_terminate_backend(pg_stat_activity.pid) FROM pg_stat_activity WHERE pid <> pg_backend_pid();"')
                    #     os.system(
                    #         'dropdb bemossdb')  # This step requires all connections to be closed
                    #     os.system('createdb bemossdb -O admin')
                    #     dump_result = subprocess.check_output('psql bemossdb < ' + newpath, shell=True)
                    # except Exception as er:
                    #     print "Couldn't sync database with parent because of error: "
                    #     print er
                    #
                    # parent_node_name = topic_list[3]
                    # self.updateParent(parent_node_name)

                if topic_list[1] == 'republish':
                    target = msg['target']
                    actual_message = msg['actual_message']
                    actual_topic = msg['actual_topic']
                    self.bemoss_publish(target=target,
                                        topic=actual_topic + '/republished',
                                        message=actual_message,
                                        sender=sender)

            print self.node_name + ": " + topic, str(msg)

        else:
            time.sleep(2)

    def cleanup(self):
        """Stop the authenticator held in ``self.auth``."""
        authenticator = self.auth
        authenticator.stop()

    def updateMultinodeData(self, dbcon, sender, topic, message):
        """Reconfigure the authenticator and, if requested, reload the node list.

        The authenticator is always reconfigured.  Only when the fifth topic
        segment is ``update_multinode_data`` (topic layout
        ``to/multinodeagent/from/<anything>/update_multinode_data``) is the
        multinode data reloaded via ``getMultinodeData``; nodes that appear in
        the reloaded data are connected to and nodes that disappeared are
        disconnected from.

        Shorter topics (for example ``'update_parent'``) are accepted and only
        trigger the authenticator reconfiguration; they used to raise
        ``IndexError``.

        Args:
            dbcon: Database connection (not used here).
            sender: Originator of the message (not used here).
            topic: Slash-separated topic string.
            message: Message payload (not used here).
        """
        print("Updating Multinode data")
        topic_list = topic.split('/')
        self.configure_authenticator()

        # Guard the index: not every caller sends a five-segment topic.
        if len(topic_list) > 4 and topic_list[4] == 'update_multinode_data':
            # presumably getMultinodeData() rebinds self.multinode_data rather
            # than mutating it in place; otherwise no difference can be
            # detected below -- TODO confirm.
            old_nodes = self.multinode_data['known_nodes']
            self.getMultinodeData()
            new_nodes = self.multinode_data['known_nodes']

            for node in new_nodes:
                if node not in old_nodes:
                    print("New node has been added to the cluster: " +
                          node['name'])
                    print("We will connect to it")
                    self.connect_client(node)

            for node in old_nodes:
                if node not in new_nodes:
                    print("Node has been removed from the cluster: " +
                          node['name'])
                    print("We will disconnect from it")
                    self.disconnect_client(node)
                    # TODO: remove it from the node_info table

        print("yay! got it")

    def relayDirectMessage(self, dbcon, sender, topic, message):
        """Forward a direct message towards the node(s) hosting its targets.

        Topic layout: ``to/<some_agent_or_ui>/topic/from/<some_agent_or_ui>``.

        Each entry of ``message['target']`` is handled as follows:

        * listed in ``settings.NO_FORWARD_AGENTS``: stop processing entirely;
        * listed in ``settings.PARENT_NODE_SYSTEM_AGENTS``: republish to the
          parent unless this node is the parent;
        * ``"ALL"``: disperse the message;
        * anything else: look up the agent's current node in the node-devices
          table and send a ``<node>/republish/<topic>`` message through
          ``self.server`` when it lives on another node; disperse the message
          when the agent is not in the table.

        Args:
            dbcon: Database connection used for the node lookup.
            sender: Originator of the message.
            topic: Original topic string.
            message: Mapping with ``'target'``, ``'actual_message'`` and
                ``'actual_topic'`` keys.
        """
        print(topic)

        recipients = message['target']
        # These two values are not used below; the lookups are kept so that a
        # message lacking either key still fails with KeyError.
        _payload = message['actual_message']
        _payload_topic = message['actual_topic']

        for recipient in recipients:
            if recipient in settings.NO_FORWARD_AGENTS:
                # No forwarding should be done for these agents.
                # NOTE(review): this ends the whole loop, so targets listed
                # after this one are skipped as well -- confirm `continue`
                # was not intended.
                return

            if recipient in settings.PARENT_NODE_SYSTEM_AGENTS:
                if not self.is_parent:
                    self.republishToParent(sender, topic, message)
                continue

            if recipient == "ALL":
                self.disperseMessage(sender, topic=topic, message=message)
                continue

            dbcon.execute(
                "SELECT current_node_id FROM " + node_devices_table +
                " WHERE agent_id=%s", (recipient, ))
            if not dbcon.rowcount:
                # Unknown location: republish to all nodes.
                self.disperseMessage(sender, topic, message)
                continue

            node_id = dbcon.fetchone()[0]
            if node_id != self.node_index:
                forward_topic = (self.getNodeName(node_id) + '/republish/' +
                                 topic)
                self.server.send(jsonify(sender, forward_topic, message))
Exemple #18
0
class SecurityManager(object):
    """
    Manages the security related settings and actions.

    Attributes:
        private_cert_dir (str):
            The path towards the directory that stores private certificates.
        public_cert_dir (str):
            The path towards the directory that stores public certificates.
        temp_cert_dir (str):
            The path towards the temporary directory (used for generating
            new certificates). If not provided, it defaults to system's
            temporary directory.
        no_encryption (bool):
            Enable or disable encryption at peer level. Peers that use
            encryption can only connect to other peers that use encryption
            and vice-versa.
        public_file (str):
            The path towards the public certificate of this peer.
        private_file (str):
            The path towards the private certificate of this peer.
        auth_thread (ThreadAuthenticator):
            A separate thread used by zmq to authenticate our peers.

    """
    def __init__(self,
                 private_cert_dir,
                 public_cert_dir,
                 temp_cert_dir=None,
                 no_encryption=False,
                 *args,
                 **kwargs):
        """
        Constructor.

        Arguments:
            private_cert_dir (str):
                The path towards the directory that stores private
                certificates.
            public_cert_dir (str):
                The path towards the directory that stores public certificates.
            temp_cert_dir (str):
                The path towards the temporary directory (used for generating
                new certificates). Defaults to system's temporary directory.
            no_encryption (bool):
                Enable or disable encryption at peer level. Peers that use
                encryption can only connect to other peers that use encryption
                and vice-versa.
        """
        super(SecurityManager, self).__init__(*args, **kwargs)
        self.private_cert_dir = private_cert_dir
        self.public_cert_dir = public_cert_dir
        if temp_cert_dir is None:
            self.temp_cert_dir = tempfile.gettempdir()
        else:
            self.temp_cert_dir = temp_cert_dir

        # Paths for certificates used by this instance; filled in by
        # prepare_cert_store().
        self.public_file = None
        self.private_file = None

        # Security settings.
        self.no_encryption = no_encryption

        # Authentication thread; created by start_auth().
        self.auth_thread = None

    def start_auth(self, context):
        """
        Starts the authentication thread if encryption is enabled.

        Arguments:
            context (zmq.Context):
                The context where the authentication thread will belong.
        """
        if self.no_encryption:
            return

        logger.debug("Authenticator thread is being started")
        self.auth_thread = ThreadAuthenticator(
            context=context,
            encoding='utf-8',
            log=logging.getLogger('zmq_auth'))
        self.auth_thread.start()
        self.auth_thread.thread.name = 'zmq_auth'

        self.auth_thread.configure_curve(domain='*',
                                         location=self.public_cert_dir)

    def terminate_auth(self):
        """
        Ends the authentication thread if encryption is enabled.

        This method should be written defensively, as the environment
        might not be fully set (an exception in
        :meth:`p2p0mq.app.theapp.LocalPeer.create` does not prevent
        this method from being executed).
        """
        if self.auth_thread is not None:
            logger.debug("Authenticator thread is being stopped")
            self.auth_thread.stop()
            self.auth_thread = None

    def prepare_cert_store(self, uuid):
        """
        Prepares the directory structure before it can
        be used by our authentication system.

        Creates the private, public and temporary directories when they are
        missing, then makes sure the certificate pair for ``uuid`` exists and
        stores its paths in ``public_file`` and ``private_file``.

        Arguments:
            uuid:
                The unique identification of local peer.
        """
        for directory in (self.private_cert_dir, self.public_cert_dir,
                          self.temp_cert_dir):
            if not os.path.isdir(directory):
                os.makedirs(directory)

        self.public_file, self.private_file = \
            self.cert_pair_check_gen(uuid)

    def cert_pair_check_gen(self, uuid):
        """
        Checks if the certificates exist. Generates them if they don't.

        Arguments:
            uuid:
                The unique identification of the peer. Usually, this is
                the local peer.

        Returns:
            A ``(public_path, private_path)`` tuple.

        Raises:
            RuntimeError:
                The public certificate exists but the private one is missing.
        """
        cert_pub = self.cert_file_by_uuid(uuid, public=True)
        cert_prv = self.cert_file_by_uuid(uuid, public=False)
        pub_exists = os.path.isfile(cert_pub)
        prv_exists = os.path.isfile(cert_prv)

        if pub_exists and prv_exists:
            # Both files exist; nothing to do.
            pass
        elif pub_exists:
            # The public certificate exists but is unusable without
            # the private one.
            raise RuntimeError(
                "The public certificate has been found at %s, "
                "which indicates that a key has been generated, "
                "but the private certificate is not at %s"
                % (cert_pub, cert_prv))
        elif prv_exists:
            # The private certificate exists but the public one doesn't.
            # The public certificate is the private one without the line
            # that holds the secret key.  pyzmq writes that line as
            # ``secret-key = "..."``; ``private-key`` is accepted too because
            # that is the spelling this code originally looked for.
            with open(cert_prv, 'r') as fin:
                public_lines = [
                    line for line in fin
                    if not re.match(r'\s*(?:secret|private)-key\s*=', line)
                ]
            with open(cert_pub, 'w') as fout:
                fout.writelines(public_lines)
        else:
            # Neither exists: generate a fresh pair in the temporary
            # directory (named after the current time) and move both files
            # into the store.
            public_file, secret_file = \
                zmq.auth.create_certificates(
                    self.temp_cert_dir,
                    '%r' % time())
            shutil.move(public_file, cert_pub)
            shutil.move(secret_file, cert_prv)
        return cert_pub, cert_prv

    def cert_file_by_uuid(self, uuid, public=True):
        """
        Computes the path of a certificate inside the certificate store
        based on the name of the peer.

        Arguments:
            uuid:
                The unique identification of the peer. Usually, this is
                the local peer. ``bytes`` values are decoded as UTF-8.
            public (bool):
                If True it retrieves the path of the public certificate,
                if False it retrieves the path of the private certificate.

        Returns:
            ``<public_cert_dir>/<uuid>.key`` or
            ``<private_cert_dir>/<uuid>.key_secret``.
        """
        if isinstance(uuid, bytes):
            uuid = uuid.decode('utf-8')
        if public:
            directory, extension = self.public_cert_dir, 'key'
        else:
            directory, extension = self.private_cert_dir, 'key_secret'
        return os.path.join(directory, '%s.%s' % (uuid, extension))

    def cert_key_by_uuid(self, uuid, public=True):
        """
        Reads the key from corresponding certificate file.

        Arguments:
            uuid:
                The unique identification of the peer. Usually, this is
                the local peer.
            public (bool):
                If True it retrieves the public key,
                if False it retrieves the private key.

        Returns:
            The requested key, or None when the certificate file does
            not exist.
        """
        cert_file = self.cert_file_by_uuid(uuid=uuid, public=public)
        logger.debug("%s certificate for uuid %s is loaded from %s",
                     'Public' if public else 'Private', uuid, cert_file)
        if not os.path.exists(cert_file):
            return None
        public_key, secret_key = zmq.auth.load_certificate(cert_file)
        return public_key if public else secret_key

    def exchange_certificates(self, other):
        """
        Copies the certificates so that the two instances can
        authenticate themselves to each other.

        Each manager's public certificate is copied into the other's public
        directory; every authenticator that is already running is then told
        to reload its directory.  Nothing is done (apart from logging an
        error) when either side has encryption disabled.

        Arguments:
            other (SecurityManager):
                The security manager of the other peer.
        """
        if self.no_encryption:
            logger.error("Encryption was disabled in %s", self)
            return
        if other.no_encryption:
            logger.error("Encryption was disabled in %s", other)
            return

        for source, destination in ((self, other), (other, self)):
            shutil.copy(
                source.public_file,
                os.path.join(destination.public_cert_dir,
                             os.path.basename(source.public_file)))

        for manager in (self, other):
            # An authenticator that has not been started yet needs no
            # reload: start_auth() configures it with the same directory.
            if manager.auth_thread is None:
                continue
            manager.auth_thread.configure_curve(
                domain='*', location=manager.public_cert_dir)
Exemple #19
0
class Command(LAVADaemonCommand):
    help = "LAVA log recorder"
    logger = None
    default_logfile = "/var/log/lava-server/lava-logs.log"

    def __init__(self, *args, **options):
        """Set up the logger and the (still empty) daemon state.

        Sockets, the poller and the file descriptors are created later by
        :meth:`handle`; until then the corresponding attributes are ``None``.
        """
        super(Command, self).__init__(*args, **options)
        self.logger = logging.getLogger("lava-logs")

        # Networking and polling handles.
        self.log_socket = self.controler = self.auth = None
        self.poller = self.pipe_r = self.inotify_fd = None
        self.cert_dir_path = None

        # Per-job log handlers, keyed by job id.
        self.jobs = {}
        # Test cases kept in memory until the next bulk insert.
        self.test_cases = []

        # Master status: time of the last PING and interval between PINGs.
        self.last_ping = 0
        self.ping_interval = TIMEOUT

    def add_arguments(self, parser):
        """Register the ``network`` option group on top of the inherited options."""
        super(Command, self).add_arguments(parser)

        # (flag, keyword arguments for add_argument), in registration order.
        network_options = (
            ('--socket',
             dict(default='tcp://*:5555',
                  help="Socket waiting for logs. Default: tcp://*:5555")),
            ('--master-socket',
             dict(default='tcp://localhost:5556',
                  help="Socket for master-slave communication. Default: tcp://localhost:5556")),
            ('--ipv6',
             dict(default=False, action='store_true',
                  help="Enable IPv6 on the listening sockets")),
            ('--encrypt',
             dict(default=False, action='store_true',
                  help="Encrypt messages")),
            ('--master-cert',
             dict(default='/etc/lava-dispatcher/certificates.d/master.key_secret',
                  help="Certificate for the master socket")),
            ('--slaves-certs',
             dict(default='/etc/lava-dispatcher/certificates.d',
                  help="Directory for slaves certificates")),
        )

        group = parser.add_argument_group("network")
        for flag, settings_for_flag in network_options:
            group.add_argument(flag, **settings_for_flag)

    def handle(self, *args, **options):
        """Entry point of the daemon: set up, run the main loop, shut down.

        Sequence:

        1. Configure logging and drop privileges (abort if that fails).
        2. Create the PULL socket receiving logs and the ROUTER socket talking
           to the master; optionally enable IPv6 and CURVE encryption.
        3. Watch the slaves' certificate directory, bind/connect the sockets
           and register everything (plus the signal pipe) with a poller.
        4. Send an initial PING to the master and run :meth:`main_loop`.
        5. On exit, close the controller socket, unbind the log socket, drain
           the remaining messages, flush the cached test cases and release
           the zmq resources.

        Options read: ``level``, ``log_file``, ``user``, ``group``, ``ipv6``,
        ``encrypt``, ``master_cert``, ``slaves_certs``, ``socket`` and
        ``master_socket``.
        """
        # Initialize logging.
        self.setup_logging("lava-logs", options["level"],
                           options["log_file"], FORMAT)

        self.logger.info("[INIT] Dropping privileges")
        if not self.drop_privileges(options['user'], options['group']):
            self.logger.error("[INIT] Unable to drop privileges")
            return

        # Create the sockets
        context = zmq.Context()
        self.log_socket = context.socket(zmq.PULL)
        self.controler = context.socket(zmq.ROUTER)
        self.controler.setsockopt(zmq.IDENTITY, b"lava-logs")
        # Limit the number of messages in the queue
        self.controler.setsockopt(zmq.SNDHWM, 2)
        # From http://api.zeromq.org/4-2:zmq-setsockopt#toc5
        # "Immediately readies that connection for data transfer with the master"
        self.controler.setsockopt(zmq.CONNECT_RID, b"master")

        if options['ipv6']:
            self.logger.info("[INIT] Enabling IPv6")
            self.log_socket.setsockopt(zmq.IPV6, 1)
            self.controler.setsockopt(zmq.IPV6, 1)

        if options['encrypt']:
            self.logger.info("[INIT] Starting encryption")
            try:
                self.auth = ThreadAuthenticator(context)
                self.auth.start()
                self.logger.debug("[INIT] Opening master certificate: %s", options['master_cert'])
                master_public, master_secret = zmq.auth.load_certificate(options['master_cert'])
                self.logger.debug("[INIT] Using slaves certificates from: %s", options['slaves_certs'])
                self.auth.configure_curve(domain='*', location=options['slaves_certs'])
            except IOError as err:
                self.logger.error("[INIT] %s", err)
                self.auth.stop()
                # NOTE(review): this early return leaves the sockets and the
                # context created above open.
                return
            # The log socket acts as CURVE server; the controller is a CURVE
            # client that uses the master's public key as the server key.
            self.log_socket.curve_publickey = master_public
            self.log_socket.curve_secretkey = master_secret
            self.log_socket.curve_server = True
            self.controler.curve_publickey = master_public
            self.controler.curve_secretkey = master_secret
            self.controler.curve_serverkey = master_public

        # The certificate directory is watched whether or not --encrypt was
        # given; without --encrypt self.auth stays None.
        self.logger.debug("[INIT] Watching %s", options["slaves_certs"])
        self.cert_dir_path = options["slaves_certs"]
        self.inotify_fd = watch_directory(options["slaves_certs"])
        if self.inotify_fd is None:
            # Not fatal: the daemon runs without certificate reloading.
            self.logger.error("[INIT] Unable to start inotify")

        self.log_socket.bind(options['socket'])
        self.controler.connect(options['master_socket'])

        # Poll on the sockets. This allow to have a
        # nice timeout along with polling.
        self.poller = zmq.Poller()
        self.poller.register(self.log_socket, zmq.POLLIN)
        self.poller.register(self.controler, zmq.POLLIN)
        if self.inotify_fd is not None:
            self.poller.register(os.fdopen(self.inotify_fd), zmq.POLLIN)

        # Translate signals into zmq messages
        (self.pipe_r, _) = self.setup_zmq_signal_handler()
        self.poller.register(self.pipe_r, zmq.POLLIN)

        self.logger.info("[INIT] listening for logs")
        # PING right now: the master is waiting for this message to start
        # scheduling.
        self.controler.send_multipart([b"master", b"PING"])

        try:
            self.main_loop()
        except BaseException as exc:
            # Any exception (including KeyboardInterrupt/SystemExit) only ends
            # the loop; the shutdown sequence below still runs.
            self.logger.error("[EXIT] Unknown exception raised, leaving!")
            self.logger.exception(exc)

        # Close the controler socket
        self.controler.close(linger=0)
        self.poller.unregister(self.controler)

        # Carefully close the logging socket as we don't want to lose messages
        self.logger.info("[EXIT] Disconnect logging socket and process messages")
        endpoint = u(self.log_socket.getsockopt(zmq.LAST_ENDPOINT))
        self.logger.debug("[EXIT] unbinding from '%s'", endpoint)
        self.log_socket.unbind(endpoint)

        # Empty the queue
        try:
            # leaving=True: wait_for_messages returns False once a poll
            # comes back empty.
            while self.wait_for_messages(True):
                # Flush test cases cache for every iteration because we might
                # get killed soon.
                self.flush_test_cases()
        except BaseException as exc:
            self.logger.error("[EXIT] Unknown exception raised, leaving!")
            self.logger.exception(exc)
        finally:
            # Last flush
            self.flush_test_cases()
            self.logger.info("[EXIT] Closing the logging socket: the queue is empty")
            self.log_socket.close()
            if options['encrypt']:
                self.auth.stop()
            context.term()

    def flush_test_cases(self):
        """Bulk-insert the cached test cases, then start a new empty cache.

        Does nothing when the cache is empty.
        """
        pending = self.test_cases
        if not pending:
            return
        self.logger.info("Saving %d test cases", len(pending))
        TestCase.objects.bulk_create(pending)
        self.test_cases = []

    def main_loop(self):
        """Process messages until :meth:`wait_for_messages` asks to stop.

        After every poll three periodic chores are checked: saving the cached
        test cases (every ``BULK_CREATE_TIMEOUT`` seconds), closing log
        handlers unused for more than ``FD_TIMEOUT`` seconds, and pinging the
        master (every ``self.ping_interval`` seconds).
        """
        last_gc = time.time()
        last_bulk_create = time.time()

        # TODO: fix timeout computation
        while self.wait_for_messages(False):
            now = time.time()

            # Periodically dump the cached TestCase objects into the database.
            bulk_create_due = now - last_bulk_create > BULK_CREATE_TIMEOUT
            if bulk_create_due:
                last_bulk_create = now
                self.flush_test_cases()

            # Periodically close handlers of jobs that went quiet.
            gc_due = now - last_gc > FD_TIMEOUT
            if gc_due:
                last_gc = now
                # Work on a snapshot: entries are deleted while looping.
                for job_id, handler in list(self.jobs.items()):
                    if now - handler.last_usage > FD_TIMEOUT:
                        self.logger.info("[%s] closing log file", job_id)
                        handler.close()
                        del self.jobs[job_id]

            # Periodically ping the master.
            ping_due = now - self.last_ping > self.ping_interval
            if ping_due:
                self.logger.debug("PING => master")
                self.last_ping = now
                self.controler.send_multipart([b"master", b"PING"])

    def wait_for_messages(self, leaving):
        """Poll once (``TIMEOUT`` seconds at most) and handle what is readable.

        At most one source is handled per call, checked in this order: the
        log socket, the signal pipe, the controller socket, the inotify
        descriptor.

        Arguments:
            leaving (bool):
                False during normal operation, True while the queue is being
                drained before exit.

        Returns:
            bool: True when the caller should keep polling.  False is
            returned when a signal arrives during normal operation, or when
            nothing was received while leaving.  Polling errors and database
            connection resets are logged and reported as True.
        """
        try:
            try:
                sockets = dict(self.poller.poll(TIMEOUT * 1000))
            except zmq.error.ZMQError as exc:
                self.logger.error("[POLL] zmq error: %s", str(exc))
                return True

            # Messages
            if sockets.get(self.log_socket) == zmq.POLLIN:
                self.logging_socket()
                return True

            # Signals
            if sockets.get(self.pipe_r) == zmq.POLLIN:
                # remove the message from the queue
                os.read(self.pipe_r, 1)

                if leaving:
                    self.logger.warning("[POLL] signal already handled, please wait for the process to exit")
                    return True
                self.logger.info("[POLL] received a signal, leaving")
                return False

            # Pong received
            if sockets.get(self.controler) == zmq.POLLIN:
                self.controler_socket()
                return True

            # Inotify socket
            if sockets.get(self.inotify_fd) == zmq.POLLIN:
                # Always drain the event, otherwise the descriptor stays
                # readable and every poll returns immediately.
                os.read(self.inotify_fd, 4096)
                if self.auth is None:
                    # The directory is watched even when encryption is
                    # disabled; there is no authenticator to reload then.
                    self.logger.debug("[AUTH] Encryption disabled, ignoring changes in %s",
                                      self.cert_dir_path)
                else:
                    self.logger.debug("[AUTH] Reloading certificates from %s",
                                      self.cert_dir_path)
                    self.auth.configure_curve(domain='*',
                                              location=self.cert_dir_path)
                return True

            # Nothing received
            return not leaving

        except (OperationalError, InterfaceError):
            self.logger.info("[RESET] database connection reset")
            connection.close()
        return True

    def logging_socket(self):
        """Handle one message read from the log socket.

        The message is expected to carry two frames: a job id and a YAML
        document holding "lvl" and "msg" keys.  Malformed messages are logged
        and dropped.  For "results" messages a test case is queued in
        ``self.test_cases``; a lava/job result additionally flushes the queued
        test cases and moves the job to its finished state.  Every accepted
        message is finally written through the job's ``JobHandler``.
        """
        msg = self.log_socket.recv_multipart()
        try:
            # Exactly two frames are expected; any other count raises ValueError.
            (job_id, message) = (u(m) for m in msg)  # pylint: disable=unbalanced-tuple-unpacking
        except ValueError:
            # do not let a bad message stop the master.
            self.logger.error("[POLL] failed to parse log message, skipping: %s", msg)
            return

        try:
            # NOTE(review): yaml.CLoader is PyYAML's full (non-safe) loader. If
            # the log socket can be reached by untrusted peers, yaml.CSafeLoader
            # should be used instead -- confirm that no producer relies on
            # Python-specific YAML tags before switching.
            scanned = yaml.load(message, Loader=yaml.CLoader)
        except yaml.YAMLError:
            self.logger.error("[%s] data are not valid YAML, dropping", job_id)
            return

        # Look for "results" level
        try:
            message_lvl = scanned["lvl"]
            message_msg = scanned["msg"]
        except TypeError:
            # The document is not subscriptable by key (e.g. a scalar or None).
            self.logger.error("[%s] not a dictionary, dropping", job_id)
            return
        except KeyError:
            self.logger.error(
                "[%s] invalid log line, missing \"lvl\" or \"msg\" keys: %s",
                job_id, message)
            return

        # Find the handler (if available); self.jobs maps job id -> JobHandler
        if job_id not in self.jobs:
            # Query the database for the job
            try:
                job = TestJob.objects.get(id=job_id)
            except TestJob.DoesNotExist:
                self.logger.error("[%s] unknown job id", job_id)
                return

            self.logger.info("[%s] receiving logs from a new job", job_id)
            # Create the sub directories (if needed)
            mkdir(job.output_dir)
            self.jobs[job_id] = JobHandler(job)

        if message_lvl == "results":
            # The job is fetched again here because the lookup above only runs
            # for job ids that have no handler yet.
            try:
                job = TestJob.objects.get(pk=job_id)
            except TestJob.DoesNotExist:
                self.logger.error("[%s] unknown job id", job_id)
                return
            meta_filename = create_metadata_store(message_msg, job)
            new_test_case = map_scanned_results(results=message_msg, job=job,
                                                meta_filename=meta_filename)
            if new_test_case is None:
                self.logger.warning(
                    "[%s] unable to map scanned results: %s",
                    job_id, message)
            else:
                # Kept in memory until flush_test_cases() is called
                self.test_cases.append(new_test_case)

            # Look for lava.job result
            # NOTE(review): assumes message_msg is a mapping here -- TODO confirm
            if message_msg.get("definition") == "lava" and message_msg.get("case") == "job":
                # Flush cached test cases
                self.flush_test_cases()

                if message_msg.get("result") == "pass":
                    health = TestJob.HEALTH_COMPLETE
                    health_msg = "Complete"
                else:
                    health = TestJob.HEALTH_INCOMPLETE
                    health_msg = "Incomplete"
                self.logger.info("[%s] job status: %s", job_id, health_msg)

                # These three error types are reported as infrastructure errors
                infrastructure_error = (message_msg.get("error_type") in ["Bug",
                                                                          "Configuration",
                                                                          "Infrastructure"])
                if infrastructure_error:
                    self.logger.info("[%s] Infrastructure error", job_id)

                # Update status.
                with transaction.atomic():
                    # TODO: find a way to lock actual_device
                    job = TestJob.objects.select_for_update() \
                                         .get(id=job_id)
                    job.go_state_finished(health, infrastructure_error)
                    job.save()

        # Mark the file handler as used
        self.jobs[job_id].last_usage = time.time()

        # n.b. logging here would produce a log entry for every message in every job.
        # The format is a list of dictionaries
        message = "- %s" % message

        # Write data
        self.jobs[job_id].write(message)

    def controler_socket(self):
        """Read one answer from the controler socket and adopt its ping interval.

        The answer must hold three frames: the sender id ("master"), the
        action ("PONG") and an integer ping interval.  Anything else is
        logged and ignored, as is an interval smaller than ``TIMEOUT``.
        """
        frames = self.controler.recv_multipart()
        try:
            master_id, action = u(frames[0]), u(frames[1])
            ping_interval = int(frames[2])

            # (received value, expected value, log template)
            expectations = (
                (master_id, "master", "Invalid master id '%s'. Should be 'master'"),
                (action, "PONG", "Invalid answer '%s'. Should be 'PONG'"),
            )
            for received, expected, template in expectations:
                if received != expected:
                    self.logger.error(template, received)
                    return
        except (IndexError, ValueError):
            self.logger.error("Invalid message '%s'", frames)
            return

        if ping_interval >= TIMEOUT:
            self.logger.debug("master => PONG(%d)", ping_interval)
            self.ping_interval = ping_interval
        else:
            self.logger.error("invalid ping interval (%d) too small", ping_interval)
Exemple #20
0
class RpcClient:
    """Client half of a ZeroMQ based RPC channel.

    Any attribute that is not defined on the instance is treated as the name
    of a remote function: ``client.foo(1, bar=2)`` sends
    ``["foo", (1,), {"bar": 2}]`` on the REQ socket and returns the second
    element of the reply.  A SUB socket is read on a worker thread;
    subclasses implement :meth:`callback` to consume what arrives on it.
    """

    def __init__(self):
        """Create the context, both sockets and the idle worker state."""
        # zmq port related
        self.__context: zmq.Context = zmq.Context()

        # Request socket (Request–reply pattern)
        self.__socket_req: zmq.Socket = self.__context.socket(zmq.REQ)

        # Subscribe socket (Publish–subscribe pattern)
        self.__socket_sub: zmq.Socket = self.__context.socket(zmq.SUB)

        # Worker thread related, used to process data pushed from server
        self.__active: bool = False  # RpcClient status
        self.__thread: threading.Thread = None  # RpcClient thread
        self.__lock: threading.Lock = threading.Lock()

        # Authenticator used to ensure data security
        self.__authenticator: ThreadAuthenticator = None

        self._last_received_ping: datetime = datetime.utcnow()

    def __getattr__(self, name: str):
        """
        Return a callable that performs the remote call ``name``.

        No caching decorator is applied on purpose: ``functools.lru_cache`` on
        a method stores ``self`` in a class-level cache, which keeps every
        client (and its sockets) alive and makes all instances compete for
        the same entries.  Building the closure is negligible next to a
        network round trip.
        """

        def dorpc(*args, **kwargs):
            """Send the request and return the reply payload.

            The optional ``timeout`` keyword (milliseconds, 30 seconds by
            default) is consumed locally and not forwarded to the server.
            Raises RemoteException on timeout or when the reply is flagged
            as failed.
            """
            timeout = kwargs.pop("timeout", 30000)

            # Generate request
            req = [name, args, kwargs]

            # The lock serialises the send/receive pair on the shared socket
            with self.__lock:
                self.__socket_req.send_pyobj(req)

                # NOTE(review): when the timeout is reached the REQ socket has
                # sent a request without receiving its reply; ZeroMQ REQ
                # sockets normally refuse a further send in that state --
                # confirm how callers are expected to recover.
                if not self.__socket_req.poll(timeout):
                    msg = f"Timeout of {timeout}ms reached for {req}"
                    raise RemoteException(msg)

                # NOTE(review): recv_pyobj unpickles the reply; only connect
                # to servers you trust.
                rep = self.__socket_req.recv_pyobj()

            # Return response if succeeded; raise if the server flagged failure
            if rep[0]:
                return rep[1]
            raise RemoteException(rep[1])

        return dorpc

    def start(self,
              req_address: str,
              sub_address: str,
              client_secretkey_path: str = "",
              server_publickey_path: str = "",
              username: str = "",
              password: str = "") -> None:
        """
        Start RpcClient

        CURVE is configured when both key paths are given, otherwise PLAIN
        when both ``username`` and ``password`` are given, otherwise no
        security mechanism is set.  Does nothing if already started.
        """
        if self.__active:
            return

        # Start authenticator
        if client_secretkey_path and server_publickey_path:
            self.__authenticator = ThreadAuthenticator(self.__context)
            self.__authenticator.start()
            self.__authenticator.configure_curve(
                domain="*", location=zmq.auth.CURVE_ALLOW_ANY)

            publickey, secretkey = zmq.auth.load_certificate(
                client_secretkey_path)
            serverkey, _ = zmq.auth.load_certificate(server_publickey_path)

            self.__socket_sub.curve_secretkey = secretkey
            self.__socket_sub.curve_publickey = publickey
            self.__socket_sub.curve_serverkey = serverkey

            self.__socket_req.curve_secretkey = secretkey
            self.__socket_req.curve_publickey = publickey
            self.__socket_req.curve_serverkey = serverkey
        elif username and password:
            self.__authenticator = ThreadAuthenticator(self.__context)
            self.__authenticator.start()
            self.__authenticator.configure_plain(
                domain="*", passwords={username: password})

            self.__socket_sub.plain_username = username.encode()
            self.__socket_sub.plain_password = password.encode()

            self.__socket_req.plain_username = username.encode()
            self.__socket_req.plain_password = password.encode()

        # Connect zmq port
        self.__socket_req.connect(req_address)
        self.__socket_sub.connect(sub_address)

        # Start RpcClient status
        self.__active = True

        # Start RpcClient thread
        self.__thread = threading.Thread(target=self.run)
        self.__thread.start()

        self._last_received_ping = datetime.utcnow()

    def stop(self) -> None:
        """
        Stop RpcClient

        Only clears the running flag; the worker thread notices it on its
        next loop iteration and then closes the sockets itself.
        """
        if not self.__active:
            return

        # Stop RpcClient status
        self.__active = False

    def join(self) -> None:
        """Wait for the RpcClient thread to exit, then forget it."""
        if self.__thread and self.__thread.is_alive():
            self.__thread.join()
        self.__thread = None

    def run(self) -> None:
        """
        Run RpcClient function

        Worker loop: waits on the SUB socket, reports a lost heartbeat via
        :meth:`on_disconnected` when nothing arrives within the tolerance,
        and hands every non keep-alive message to :meth:`callback`.
        """
        pull_tolerance = int(KEEP_ALIVE_TOLERANCE.total_seconds() * 1000)

        while self.__active:
            if not self.__socket_sub.poll(pull_tolerance):
                self.on_disconnected()
                continue

            # Receive data from subscribe socket
            topic, data = self.__socket_sub.recv_pyobj(flags=NOBLOCK)

            if topic == KEEP_ALIVE_TOPIC:
                self._last_received_ping = data
            else:
                # Process data by callable function
                self.callback(topic, data)

        # Close socket
        self.__socket_req.close()
        self.__socket_sub.close()

        if self.__authenticator:
            self.__authenticator.stop()

    def callback(self, topic: str, data: Any) -> None:
        """
        Callable function, must be implemented by subclasses.
        """
        raise NotImplementedError

    def subscribe_topic(self, topic: str) -> None:
        """
        Subscribe data
        """
        self.__socket_sub.setsockopt_string(zmq.SUBSCRIBE, topic)

    def on_disconnected(self):
        """
        Callback when heartbeat is lost.
        """
        print(
            "RpcServer has no response over {tolerance} seconds, please check you connection."
            .format(tolerance=KEEP_ALIVE_TOLERANCE.total_seconds()))
Exemple #21
0
class StratusApp(StratusServerApp):
    """ZeroMQ front end of a Stratus server.

    Requests arrive on a REP socket as ``id!type[!json]`` strings and are
    answered on the same socket; workflow results are delivered through a
    ``StratusZMQResponder``.
    """

    def __init__(self, core: StratusCore, **kwargs):
        """Read the 'stratus' configuration section and locate the certificates."""
        StratusServerApp.__init__(self, core, **kwargs)
        self.logger = StratusLogger.getLogger()
        self.active = True
        self.parms = self.getConfigParms('stratus')
        self.client_address = self.parms.get("client_address", "*")
        self.request_port = self.parms.get("request_port", 4556)
        self.response_port = self.parms.get("response_port", 4557)
        self.active_handlers = {}
        self.getCertDirs()

    def getCertDirs(
        self
    ):  # These directories are generated by the generate_certificates script
        """Resolve the certificate directories, generating them when any is missing."""
        self.cert_dir = self.parms.get("certificate_path",
                                       os.path.expanduser("~/.stratus/zmq"))
        self.logger.info(
            f"Loading certificates and keys from directory {self.cert_dir}")
        self.keys_dir = os.path.join(self.cert_dir, 'certificates')
        self.public_keys_dir = os.path.join(self.cert_dir, 'public_keys')
        self.secret_keys_dir = os.path.join(self.cert_dir, 'private_keys')

        if not (os.path.exists(self.keys_dir)
                and os.path.exists(self.public_keys_dir)
                and os.path.exists(self.secret_keys_dir)):
            from stratus.handlers.zeromq.security.generate_certificates import generate_certificates
            generate_certificates(self.cert_dir)

    def initSocket(self):
        """Bind the request socket; failures are logged, not raised."""
        try:
            server_secret_file = os.path.join(self.secret_keys_dir,
                                              "server.key_secret")
            # The certificate is still loaded so that a missing or unreadable
            # key file is reported, even though the keys are not applied.
            server_public, server_secret = zmq.auth.load_certificate(
                server_secret_file)
            # TODO: this is commented to avoid key checking
            #self.request_socket.curve_secretkey = server_secret
            #self.request_socket.curve_publickey = server_public
            #self.request_socket.curve_server = True
            self.request_socket.bind("tcp://{}:{}".format(
                self.client_address, self.request_port))
            self.logger.info(
                "@@STRATUS-APP --> Bound authenticated request socket to client at {} on port: {}"
                .format(self.client_address, self.request_port))
        except Exception as err:
            self.logger.error(
                "@@STRATUS-APP: Error initializing request socket on {}, port {}: {}"
                .format(self.client_address, self.request_port, err))
            self.logger.error(traceback.format_exc())

    def addHandler(self, clientId, jobId, handler):
        """Register ``handler`` under the ``clientId-jobId`` key and return it."""
        self.active_handlers[clientId + "-" + jobId] = handler
        return handler

    def removeHandler(self, clientId, jobId):
        """Forget the handler registered for ``clientId-jobId``; log if unknown."""
        handlerId = clientId + "-" + jobId
        try:
            del self.active_handlers[handlerId]
        except KeyError:
            # Only an unknown handler id is expected here; anything else
            # should surface instead of being swallowed.
            self.logger.error("Error removing handler: " + handlerId +
                              ", active handlers = " +
                              str(list(self.active_handlers.keys())))

    def setExeStatus(self, submissionId: str, status: Status):
        """Forward the execution status of a submission to the responder."""
        self.responder.setExeStatus(submissionId, status)

    def sendResponseMessage(self, msg: StratusResponse) -> str:
        """Send ``msg`` on the request socket as ``id!message`` and return that string."""
        request_args = [msg.id, msg.message]
        packaged_msg = "!".join(request_args)
        # strftime needs %-directives; the previous "MM/dd HH:mm:ss" pattern
        # was emitted verbatim instead of the current time.
        timeStamp = datetime.datetime.now().strftime("%m/%d %H:%M:%S")
        self.logger.info(
            "@@STRATUS-APP: Sending response {} on request_socket @({}): {}".
            format(msg.id, timeStamp, str(msg)))
        self.request_socket.send_string(packaged_msg)
        return packaged_msg

    def initInteractions(self):
        """Create the context, authenticator, request socket and responder."""
        try:
            self.zmqContext: zmq.Context = zmq.Context()

            self.auth = ThreadAuthenticator(self.zmqContext)
            self.auth.start()
            # NOTE(review): hard-coded address; presumably a development
            # host -- confirm whether it should come from configuration.
            self.auth.allow("192.168.0.22")
            self.auth.allow(self.client_address)
            self.auth.configure_curve(
                domain='*', location=zmq.auth.CURVE_ALLOW_ANY
            )  # self.public_keys_dir )  # Use 'location=zmq.auth.CURVE_ALLOW_ANY' for stonehouse security

            self.request_socket: zmq.Socket = self.zmqContext.socket(zmq.REP)
            self.responder = StratusZMQResponder(
                self.zmqContext,
                self.response_port,
                client_address=self.client_address,
                certificate_path=self.cert_dir)
            self.initSocket()
            self.logger.info(
                "@@STRATUS-APP:Listening for requests on port: {}".format(
                    self.request_port))

        except Exception as err:
            self.logger.error(
                "@@STRATUS-APP:  ------------------------------- StratusApp Init error: {} ------------------------------- "
                .format(err))
            # Same diagnostics as initSocket: keep the stack for post-mortem.
            self.logger.error(traceback.format_exc())

    def processResults(self):
        """Let the responder process the workflows and clear the completed ones."""
        completed_workflows = self.responder.processWorkflows(
            self.getWorkflows())
        for rid in completed_workflows:
            self.clearWorkflow(rid)

    def processRequests(self):
        """Answer every request currently pending on the request socket.

        A request is ``id!type[!json]``.  Parsing happens inside the ``try``
        so that a malformed header or payload produces an error reply
        instead of escaping the loop with the request left unanswered.
        """
        while self.request_socket.poll(0) != 0:
            request_header = self.request_socket.recv_string().strip().strip(
                "'")
            parts = request_header.split("!")
            # str.split always yields at least one element
            submissionId = str(parts[0])
            try:
                if len(parts) < 2:
                    raise ValueError(
                        "Malformed request header, expected 'id!type[!json]': "
                        + request_header)
                rType = str(parts[1])
                request: Dict = json.loads(parts[2]) if len(parts) > 2 else ""
                self.logger.info(
                    "@@STRATUS-APP:  ###  Processing {} request: {}".format(
                        rType, request))
                if rType == "capabilities":
                    response = self.core.getCapabilities(request["type"])
                    self.sendResponseMessage(
                        StratusResponse(submissionId, response))
                elif rType == "exe":
                    if len(parts) <= 2:
                        raise Exception("Missing parameters to exe request")
                    request["rid"] = submissionId
                    self.logger.info(
                        "Processing zmq Request: '{}' '{}' '{}'".format(
                            submissionId, rType, str(request)))
                    self.submitWorkflow(
                        request)  #   TODO: Send results when tasks complete.
                    response = {"status": "Executing"}
                    self.sendResponseMessage(
                        StratusResponse(submissionId, response))
                elif rType == "quit" or rType == "shutdown":
                    response = {"status": "Terminating"}
                    self.sendResponseMessage(
                        StratusResponse(submissionId, response))
                    self.logger.info(
                        "@@STRATUS-APP: Received Shutdown Message")
                    # SystemExit is not an Exception subclass, so it is not
                    # caught below; raising it directly avoids depending on
                    # the interactive ``exit`` helper installed by ``site``.
                    raise SystemExit(0)
                else:
                    msg = "@@STRATUS-APP: Unknown request type: " + rType
                    self.logger.info(msg)
                    response = {"status": "error", "error": msg}
                    self.sendResponseMessage(
                        StratusResponse(submissionId, response))
            except Exception as ex:
                self.processError(submissionId, ex)

    def processError(self, rid: str, ex: Exception):
        """Log ``ex`` with its traceback and send an error response for ``rid``."""
        tb = traceback.format_exc()
        self.logger.error("@@STRATUS-APP: Execution error: " + str(ex))
        self.logger.error(tb)
        response = {"status": "error", "error": str(ex), "traceback": tb}
        self.sendResponseMessage(StratusResponse(rid, response))

    def updateInteractions(self):
        """Run one service cycle: pending requests first, then results."""
        self.processRequests()
        self.processResults()

    def term(self, msg):
        """Stop the authenticator, close the sockets and shut the app down."""
        self.logger.info("@@STRATUS-APP: !!EDAS Shutdown: " + msg)
        self.active = False
        self.auth.stop()
        self.logger.info("@@STRATUS-APP: QUIT PythonWorkerPortal")
        try:
            self.request_socket.close()
        except Exception:
            # Best effort: shutdown continues even if the close fails
            pass
        self.logger.info("@@STRATUS-APP: CLOSE request_socket")
        self.responder.close_connection()
        self.logger.info("@@STRATUS-APP: TERM responder")
        self.shutdown()
        self.logger.info("@@STRATUS-APP: shutdown complete")
Exemple #22
0
class RpcServer:
    """Server half of a ZeroMQ based RPC channel.

    Registered functions are called by name for every request received on
    the REP socket; a PUB socket carries published data and the periodic
    keep-alive message.
    """

    def __init__(self):
        """
        Constructor
        """
        # Save functions dict: key is function name, value is function object
        self.__functions: Dict[str, Any] = {}

        # Zmq port related
        self.__context: zmq.Context = zmq.Context()

        # Reply socket (Request–reply pattern)
        self.__socket_rep: zmq.Socket = self.__context.socket(zmq.REP)

        # Publish socket (Publish–subscribe pattern)
        self.__socket_pub: zmq.Socket = self.__context.socket(zmq.PUB)

        # Worker thread related
        self.__active: bool = False  # RpcServer status
        self.__thread: threading.Thread = None  # RpcServer thread
        self.__lock: threading.Lock = threading.Lock()

        # Authenticator used to ensure data security
        self.__authenticator: ThreadAuthenticator = None

    def is_active(self) -> bool:
        """Return True while the server is started."""
        return self.__active

    def start(self,
              rep_address: str,
              pub_address: str,
              server_secretkey_path: str = "",
              username: str = "",
              password: str = "") -> None:
        """
        Start RpcServer

        CURVE is configured when ``server_secretkey_path`` is given,
        otherwise PLAIN when both ``username`` and ``password`` are given,
        otherwise no security mechanism is set.  Does nothing if already
        started.
        """
        if self.__active:
            return

        # Start authenticator
        if server_secretkey_path:
            self.__authenticator = ThreadAuthenticator(self.__context)
            self.__authenticator.start()
            self.__authenticator.configure_curve(
                domain="*", location=zmq.auth.CURVE_ALLOW_ANY)

            publickey, secretkey = zmq.auth.load_certificate(
                server_secretkey_path)

            self.__socket_pub.curve_secretkey = secretkey
            self.__socket_pub.curve_publickey = publickey
            self.__socket_pub.curve_server = True

            self.__socket_rep.curve_secretkey = secretkey
            self.__socket_rep.curve_publickey = publickey
            self.__socket_rep.curve_server = True
        elif username and password:
            self.__authenticator = ThreadAuthenticator(self.__context)
            self.__authenticator.start()
            self.__authenticator.configure_plain(
                domain="*", passwords={username: password})

            self.__socket_pub.plain_server = True
            self.__socket_rep.plain_server = True

        # Bind socket address
        self.__socket_rep.bind(rep_address)
        self.__socket_pub.bind(pub_address)

        # Start RpcServer status
        self.__active = True

        # Start RpcServer thread
        self.__thread = threading.Thread(target=self.run)
        self.__thread.start()

    def stop(self) -> None:
        """
        Stop RpcServer

        Only clears the running flag; the worker thread notices it on its
        next loop iteration and then unbinds the sockets itself.
        """
        if not self.__active:
            return

        # Stop RpcServer status
        self.__active = False

    def join(self) -> None:
        """Wait for the RpcServer thread to exit, then forget it."""
        if self.__thread and self.__thread.is_alive():
            self.__thread.join()
        self.__thread = None

    def run(self) -> None:
        """
        Run RpcServer functions

        Worker loop: publishes a keep-alive once per KEEP_ALIVE_INTERVAL,
        then waits up to one second for a request and answers it with
        ``[True, result]`` or ``[False, traceback]``.
        """
        last_keep_alive = datetime.utcnow()

        while self.__active:
            now = datetime.utcnow()

            if now - last_keep_alive >= KEEP_ALIVE_INTERVAL:
                self.publish(KEEP_ALIVE_TOPIC, now)
                # Restart the interval; without this a keep-alive would be
                # published on every iteration once the first interval
                # has elapsed.
                last_keep_alive = now

            # Use poll to wait event arrival, waiting time is 1 second (1000 milliseconds)
            if not self.__socket_rep.poll(1000):
                continue

            # Receive request data from Reply socket
            # NOTE(review): recv_pyobj unpickles whatever a peer sends; only
            # expose this socket to clients you trust.
            req = self.__socket_rep.recv_pyobj()

            # Try to unpack the request and execute the callable; capture
            # exception information if it fails.  Unpacking is part of the
            # guarded section so a malformed request gets an error reply
            # rather than ending the loop with the request unanswered.
            try:
                name, args, kwargs = req
                func = self.__functions[name]
                r = func(*args, **kwargs)
                rep = [True, r]
            except Exception as e:  # noqa
                rep = [False, traceback.format_exc()]

            # send callable response by Reply socket
            self.__socket_rep.send_pyobj(rep)

        # Unbind socket address
        self.__socket_pub.unbind(self.__socket_pub.LAST_ENDPOINT)
        self.__socket_rep.unbind(self.__socket_rep.LAST_ENDPOINT)

        if self.__authenticator:
            self.__authenticator.stop()

    def publish(self, topic: str, data: Any) -> None:
        """
        Publish data
        """
        # The lock serialises sends on the shared PUB socket
        with self.__lock:
            self.__socket_pub.send_pyobj([topic, data])

    def register(self, func: Callable) -> None:
        """
        Register function under its ``__name__`` so clients can call it.
        """
        self.__functions[func.__name__] = func
Exemple #23
0
class ZmqReceiver(object):
    """Receive messages from one REP socket and any number of SUB sockets.

    Subclasses override :meth:`handle_incoming_message`; its return value is
    sent back as the reply for messages that arrived on the REP socket.
    """

    def __init__(self, zmq_rep_bind_address=None, zmq_sub_connect_addresses=None, recreate_sockets_on_timeout_of_sec=600, username=None, password=None):
        """Create the sockets; PLAIN authentication is enabled when both credentials are given."""
        self.context = zmq.Context()
        self.auth = None
        self.last_received_message = None
        self.is_running = False
        self.thread = None
        self.zmq_rep_bind_address = zmq_rep_bind_address
        self.zmq_sub_connect_addresses = zmq_sub_connect_addresses
        self.poller = zmq.Poller()
        self.sub_sockets = []
        self.rep_socket = None
        if username is not None and password is not None:
            # Start an authenticator for this context.
            # Does not appear to work on PUB/SUB (probably because the more
            # secure solutions require two way communication as well)
            self.auth = ThreadAuthenticator(self.context)
            self.auth.start()
            # Instruct authenticator to handle PLAIN requests
            self.auth.configure_plain(domain='*', passwords={username: password})

        if self.zmq_sub_connect_addresses:
            for address in self.zmq_sub_connect_addresses:
                self.sub_sockets.append(SubSocket(self.context, self.poller, address, recreate_sockets_on_timeout_of_sec))
        if zmq_rep_bind_address:
            self.rep_socket = RepSocket(self.context, self.poller, zmq_rep_bind_address, self.auth)

    # The run loop polls with a 1000 ms timeout, so it notices the cleared
    # flag once the current poll and any in-flight handler have returned.
    def stop(self):
        """Ask the run loop to finish and stop the authenticator, if any."""
        self.is_running = False
        logger.info("Closing pub and sub sockets...")
        if self.auth is not None:
            self.auth.stop()

    def run(self):
        """Poll and dispatch messages until :meth:`stop` is called, then destroy the sockets."""
        self.is_running = True

        while self.is_running:
            socks = dict(self.poller.poll(1000))
            logger.debug("Poll cycle over. checking sockets")
            if self.rep_socket:
                incoming_message = self.rep_socket.recv_string(socks)
                if incoming_message is not None:
                    self.last_received_message = incoming_message
                    logger.debug("Got info from REP socket")
                    try:
                        response_message = self.handle_incoming_message(incoming_message)
                    except Exception as e:
                        logger.error(e)
                        # A request/reply exchange needs an answer even when
                        # the handler fails, otherwise the requester is left
                        # waiting for a reply that never comes.
                        response_message = self.create_response_message(500, "Internal Server Error", None)
                    try:
                        self.rep_socket.send(response_message)
                    except Exception as e:
                        logger.error(e)
            for sub_socket in self.sub_sockets:
                incoming_message = sub_socket.recv_string(socks)
                if incoming_message is not None:
                    # Heartbeats do not count as a received message
                    if incoming_message != "zmq_sub_heartbeat":
                        self.last_received_message = incoming_message
                    logger.debug("Got info from SUB socket")
                    try:
                        self.handle_incoming_message(incoming_message)
                    except Exception as e:
                        logger.error(e)

        if self.rep_socket:
            self.rep_socket.destroy()
        for sub_socket in self.sub_sockets:
            sub_socket.destroy()

    def create_response_message(self, status_code, status_message, response_message):
        """Return the JSON reply; "response_message" is included only when it is not None."""
        if response_message is not None:
            return json.dumps({"status_code": status_code, "status_message": status_message, "response_message": response_message})
        else:
            return json.dumps({"status_code": status_code, "status_message": status_message})

    def handle_incoming_message(self, message):
        """Default handler: answer 200/OK to anything but a heartbeat (which yields None)."""
        if message != "zmq_sub_heartbeat":
            return self.create_response_message(200, "OK", None)
Exemple #24
0
class ZmqReceiver(object):
    """Receive messages from one REP socket and any number of SUB sockets.

    Subclasses override :meth:`handle_incoming_message`; its return value is
    sent back as the reply for messages that arrived on the REP socket.
    """

    def __init__(self,
                 zmq_rep_bind_address=None,
                 zmq_sub_connect_addresses=None,
                 recreate_sockets_on_timeout_of_sec=600,
                 username=None,
                 password=None):
        """Create the sockets; PLAIN authentication is enabled when both
        credentials are given."""
        self.context = zmq.Context()
        self.auth = None
        self.last_received_message = None
        self.is_running = False
        self.thread = None
        self.zmq_rep_bind_address = zmq_rep_bind_address
        self.zmq_sub_connect_addresses = zmq_sub_connect_addresses
        self.poller = zmq.Poller()
        self.sub_sockets = []
        self.rep_socket = None
        if username is not None and password is not None:
            # Start an authenticator for this context.
            # Does not appear to work on PUB/SUB (probably because the more
            # secure solutions require two way communication as well)
            self.auth = ThreadAuthenticator(self.context)
            self.auth.start()
            # Instruct authenticator to handle PLAIN requests
            self.auth.configure_plain(domain='*',
                                      passwords={username: password})

        if self.zmq_sub_connect_addresses:
            for address in self.zmq_sub_connect_addresses:
                self.sub_sockets.append(
                    SubSocket(self.context, self.poller, address,
                              recreate_sockets_on_timeout_of_sec))
        if zmq_rep_bind_address:
            self.rep_socket = RepSocket(self.context, self.poller,
                                        zmq_rep_bind_address, self.auth)

    # The run loop polls with a 1000 ms timeout, so it notices the cleared
    # flag once the current poll and any in-flight handler have returned.
    def stop(self):
        """Ask the run loop to finish and stop the authenticator, if any."""
        self.is_running = False
        logger.info("Closing pub and sub sockets...")
        if self.auth is not None:
            self.auth.stop()

    def run(self):
        """Poll and dispatch messages until :meth:`stop` is called, then
        destroy the sockets."""
        self.is_running = True

        while self.is_running:
            socks = dict(self.poller.poll(1000))
            logger.debug("Poll cycle over. checking sockets")
            if self.rep_socket:
                incoming_message = self.rep_socket.recv_string(socks)
                if incoming_message is not None:
                    self.last_received_message = incoming_message
                    logger.debug("Got info from REP socket")
                    try:
                        response_message = self.handle_incoming_message(
                            incoming_message)
                    except Exception as e:
                        logger.error(e)
                        # A request/reply exchange needs an answer even when
                        # the handler fails, otherwise the requester is left
                        # waiting for a reply that never comes.
                        response_message = self.create_response_message(
                            500, "Internal Server Error", None)
                    try:
                        self.rep_socket.send(response_message)
                    except Exception as e:
                        logger.error(e)
            for sub_socket in self.sub_sockets:
                incoming_message = sub_socket.recv_string(socks)
                if incoming_message is not None:
                    # Heartbeats do not count as a received message
                    if incoming_message != "zmq_sub_heartbeat":
                        self.last_received_message = incoming_message
                    logger.debug("Got info from SUB socket")
                    try:
                        self.handle_incoming_message(incoming_message)
                    except Exception as e:
                        logger.error(e)

        if self.rep_socket:
            self.rep_socket.destroy()
        for sub_socket in self.sub_sockets:
            sub_socket.destroy()

    def create_response_message(self, status_code, status_message,
                                response_message):
        """Return the JSON reply; "response_message" is included only when it
        is not None."""
        if response_message is not None:
            return json.dumps({
                "status_code": status_code,
                "status_message": status_message,
                "response_message": response_message
            })
        else:
            return json.dumps({
                "status_code": status_code,
                "status_message": status_message
            })

    def handle_incoming_message(self, message):
        """Default handler: answer 200/OK to anything but a heartbeat (which
        yields None)."""
        if message != "zmq_sub_heartbeat":
            return self.create_response_message(200, "OK", None)
Exemple #25
0
class Command(LAVADaemonCommand):
    help = "LAVA log recorder"
    logger = None
    default_logfile = "/var/log/lava-server/lava-logs.log"

    def __init__(self, *args, **options):
        """Create the logger and reset the runtime state of the daemon."""
        super().__init__(*args, **options)
        self.logger = logging.getLogger("lava-logs")

        # Sockets, authenticator and polling resources: unset at this point
        self.log_socket = self.controler = self.auth = None
        self.poller = self.pipe_r = self.inotify_fd = None
        self.cert_dir_path = None

        # Log handlers, indexed by job id
        self.jobs = {}
        # Test cases kept in memory until they are flushed
        self.test_cases = []

        # Master status
        self.last_ping = 0
        self.ping_interval = TIMEOUT

    def add_arguments(self, parser):
        """Declare the "network" options on top of the inherited arguments."""
        super().add_arguments(parser)

        # (flag, keyword arguments given to add_argument), in declaration order
        network_options = [
            ('--socket', {
                "default": 'tcp://*:5555',
                "help": "Socket waiting for logs. Default: tcp://*:5555",
            }),
            ('--master-socket', {
                "default": 'tcp://localhost:5556',
                "help": "Socket for master-slave communication. Default: tcp://localhost:5556",
            }),
            ('--ipv6', {
                "default": False,
                "action": 'store_true',
                "help": "Enable IPv6 on the listening sockets",
            }),
            ('--encrypt', {
                "default": False,
                "action": 'store_true',
                "help": "Encrypt messages",
            }),
            ('--master-cert', {
                "default": '/etc/lava-dispatcher/certificates.d/master.key_secret',
                "help": "Certificate for the master socket",
            }),
            ('--slaves-certs', {
                "default": '/etc/lava-dispatcher/certificates.d',
                "help": "Directory for slaves certificates",
            }),
        ]

        net = parser.add_argument_group("network")
        for flag, kwargs in network_options:
            net.add_argument(flag, **kwargs)

    def handle(self, *args, **options):
        """
        Entry point of the command.

        Sets up logging, drops privileges, dumps the options to a YAML file,
        creates and configures the sockets, runs main_loop() and, once the
        loop is over, drains the logging socket before closing everything.

        Returns early (None) when the privileges cannot be dropped or when
        the encryption setup raises OSError.
        """
        # Initialize logging.
        self.setup_logging("lava-logs", options["level"],
                           options["log_file"], FORMAT)

        self.logger.info("[INIT] Dropping privileges")
        if not self.drop_privileges(options['user'], options['group']):
            self.logger.error("[INIT] Unable to drop privileges")
            return

        # Keep a copy of the options used to start the daemon
        filename = os.path.join(settings.MEDIA_ROOT, 'lava-logs-config.yaml')
        self.logger.debug("[INIT] Dumping config to %s", filename)
        with open(filename, 'w') as output:
            yaml.dump(options, output)

        # Create the sockets
        context = zmq.Context()
        # PULL socket receiving the log lines
        self.log_socket = context.socket(zmq.PULL)
        # ROUTER socket used to exchange PING/PONG with the master
        self.controler = context.socket(zmq.ROUTER)
        self.controler.setsockopt(zmq.IDENTITY, b"lava-logs")
        # Limit the number of messages in the queue
        self.controler.setsockopt(zmq.SNDHWM, 2)
        # From http://api.zeromq.org/4-2:zmq-setsockopt#toc5
        # "Immediately readies that connection for data transfer with the master"
        self.controler.setsockopt(zmq.CONNECT_RID, b"master")

        if options['ipv6']:
            self.logger.info("[INIT] Enabling IPv6")
            self.log_socket.setsockopt(zmq.IPV6, 1)
            self.controler.setsockopt(zmq.IPV6, 1)

        if options['encrypt']:
            self.logger.info("[INIT] Starting encryption")
            try:
                self.auth = ThreadAuthenticator(context)
                self.auth.start()
                self.logger.debug("[INIT] Opening master certificate: %s", options['master_cert'])
                master_public, master_secret = zmq.auth.load_certificate(options['master_cert'])
                self.logger.debug("[INIT] Using slaves certificates from: %s", options['slaves_certs'])
                self.auth.configure_curve(domain='*', location=options['slaves_certs'])
            except OSError as err:
                self.logger.error("[INIT] %s", err)
                self.auth.stop()
                return
            # The logging socket is the CURVE server, the controler socket is
            # a CURVE client of the master (same key pair on both sockets).
            self.log_socket.curve_publickey = master_public
            self.log_socket.curve_secretkey = master_secret
            self.log_socket.curve_server = True
            self.controler.curve_publickey = master_public
            self.controler.curve_secretkey = master_secret
            self.controler.curve_serverkey = master_public

        # NOTE(review): the directory is watched even when --encrypt is not
        # set, in which case self.auth is still None -- confirm that the
        # inotify handling copes with that.
        self.logger.debug("[INIT] Watching %s", options["slaves_certs"])
        self.cert_dir_path = options["slaves_certs"]
        self.inotify_fd = watch_directory(options["slaves_certs"])
        if self.inotify_fd is None:
            self.logger.error("[INIT] Unable to start inotify")

        self.log_socket.bind(options['socket'])
        self.controler.connect(options['master_socket'])

        # Poll on the sockets. This allow to have a
        # nice timeout along with polling.
        self.poller = zmq.Poller()
        self.poller.register(self.log_socket, zmq.POLLIN)
        self.poller.register(self.controler, zmq.POLLIN)
        if self.inotify_fd is not None:
            self.poller.register(os.fdopen(self.inotify_fd), zmq.POLLIN)

        # Translate signals into zmq messages
        (self.pipe_r, _) = self.setup_zmq_signal_handler()
        self.poller.register(self.pipe_r, zmq.POLLIN)

        self.logger.info("[INIT] listening for logs")
        # PING right now: the master is waiting for this message to start
        # scheduling.
        self.controler.send_multipart([b"master", b"PING"])

        # Any exception (BaseException included) ends the loop but must not
        # prevent the shutdown sequence below from running.
        try:
            self.main_loop()
        except BaseException as exc:
            self.logger.error("[EXIT] Unknown exception raised, leaving!")
            self.logger.exception(exc)

        # Close the controler socket
        self.controler.close(linger=0)
        self.poller.unregister(self.controler)

        # Carefully close the logging socket as we don't want to lose messages
        self.logger.info("[EXIT] Disconnect logging socket and process messages")
        endpoint = u(self.log_socket.getsockopt(zmq.LAST_ENDPOINT))
        self.logger.debug("[EXIT] unbinding from '%s'", endpoint)
        self.log_socket.unbind(endpoint)

        # Empty the queue
        try:
            # wait_for_messages(True) returns False once nothing is received
            while self.wait_for_messages(True):
                # Flush test cases cache for every iteration because we might
                # get killed soon.
                self.flush_test_cases()
        except BaseException as exc:
            self.logger.error("[EXIT] Unknown exception raised, leaving!")
            self.logger.exception(exc)
        finally:
            # Last flush
            self.flush_test_cases()
            self.logger.info("[EXIT] Closing the logging socket: the queue is empty")
            self.log_socket.close()
            if options['encrypt']:
                self.auth.stop()
            context.term()

    def flush_test_cases(self):
        if not self.test_cases:
            return

        # Try to save into the database
        try:
            TestCase.objects.bulk_create(self.test_cases)
            self.logger.info("Saving %d test cases", len(self.test_cases))
            self.test_cases = []
        except DatabaseError as exc:
            self.logger.error("Unable to flush the test cases")
            self.logger.exception(exc)
            self.logger.warning("Saving test cases one by one and dropping the faulty ones")
            saved = 0
            for tc in self.test_cases:
                with contextlib.suppress(DatabaseError):
                    tc.save()
                    saved += 1
            self.logger.info("%d test cases saved, %d dropped", saved, len(self.test_cases) - saved)
            self.test_cases = []

    def main_loop(self):
        """
        Process messages until wait_for_messages() asks to stop.

        After each call, and when the corresponding delay has expired, the
        cached test cases are flushed, the idle log files are closed and
        the master is pinged.
        """
        last_gc = time.time()
        last_bulk_create = time.time()

        # Wait for messages
        # TODO: fix timeout computation
        while self.wait_for_messages(False):
            now = time.time()

            # Dump TestCase into the database
            if now - last_bulk_create > BULK_CREATE_TIMEOUT:
                last_bulk_create = now
                self.flush_test_cases()

            # Close old file handlers
            if now - last_gc > FD_TIMEOUT:
                last_gc = now
                # Work on a snapshot: entries are removed while looping
                for job_id, handler in list(self.jobs.items()):
                    if now - handler.last_usage > FD_TIMEOUT:
                        self.logger.info("[%s] closing log file", job_id)
                        handler.close()
                        del self.jobs[job_id]

            # Ping the master
            if now - self.last_ping > self.ping_interval:
                self.logger.debug("PING => master")
                self.last_ping = now
                self.controler.send_multipart([b"master", b"PING"])

    def wait_for_messages(self, leaving):
        """
        Poll the registered sockets once and handle what is ready.

        Only one source is handled per call, in this order: logging socket,
        signal pipe, controler socket, inotify descriptor.

        :param leaving: True when the daemon is already shutting down
        :return: False when the caller should stop looping: a signal was
                 received while not leaving, or nothing was received while
                 leaving. True in every other case, including zmq and
                 database errors.
        """
        try:
            try:
                sockets = dict(self.poller.poll(TIMEOUT * 1000))
            except zmq.error.ZMQError as exc:
                self.logger.error("[POLL] zmq error: %s", str(exc))
                return True

            # Messages
            if sockets.get(self.log_socket) == zmq.POLLIN:
                self.logging_socket()
                return True

            # Signals
            if sockets.get(self.pipe_r) == zmq.POLLIN:
                # remove the message from the queue
                os.read(self.pipe_r, 1)

                if leaving:
                    self.logger.warning("[POLL] signal already handled, please wait for the process to exit")
                    return True
                self.logger.info("[POLL] received a signal, leaving")
                return False

            # Pong received
            if sockets.get(self.controler) == zmq.POLLIN:
                self.controler_socket()
                return True

            # Inotify socket
            if sockets.get(self.inotify_fd) == zmq.POLLIN:
                # Consume the inotify events
                os.read(self.inotify_fd, 4096)
                if self.auth is None:
                    # The directory is watched even when encryption is
                    # disabled: there is no authenticator to reload, and
                    # calling it would raise and stop the daemon.
                    self.logger.debug("[AUTH] Encryption is disabled, ignoring changes in %s",
                                      self.cert_dir_path)
                else:
                    self.logger.debug("[AUTH] Reloading certificates from %s",
                                      self.cert_dir_path)
                    self.auth.configure_curve(domain='*',
                                              location=self.cert_dir_path)
                return True

            # Nothing received
            return not leaving

        except (OperationalError, InterfaceError):
            # Drop the broken connection and keep looping
            self.logger.info("[RESET] database connection reset")
            connection.close()
        return True

    def logging_socket(self):
        """
        Receive one log line from the logging socket and process it.

        The message must be made of two frames: the job id and a YAML
        dictionary with the "lvl" and "msg" keys. Depending on "lvl":

        * "event": an event is sent, then the line is written to the job log
        * "marker": the marker position is recorded, nothing is written
        * "results": the line is written, a test case is cached and, for the
          lava.job result, the job is moved to the finished state
        * anything else: the line is written to the job log

        Malformed messages are logged and dropped so that a bad message
        cannot stop the daemon.
        """
        msg = self.log_socket.recv_multipart()
        try:
            (job_id, message) = (u(m) for m in msg)  # pylint: disable=unbalanced-tuple-unpacking
        except ValueError:
            # do not let a bad message stop the master.
            self.logger.error("[POLL] failed to parse log message, skipping: %s", msg)
            return

        # NOTE(review): yaml.CLoader is not the "safe" loader (that one is
        # yaml.CSafeLoader). Confirm that every sender is trusted, or switch
        # to the safe loader.
        try:
            scanned = yaml.load(message, Loader=yaml.CLoader)
        except yaml.YAMLError:
            self.logger.error("[%s] data are not valid YAML, dropping", job_id)
            return

        # Look for "results" level
        try:
            message_lvl = scanned["lvl"]
            message_msg = scanned["msg"]
        except TypeError:
            self.logger.error("[%s] not a dictionary, dropping", job_id)
            return
        except KeyError:
            self.logger.error(
                "[%s] invalid log line, missing \"lvl\" or \"msg\" keys: %s",
                job_id, message)
            return

        # Find the handler (if available)
        if job_id not in self.jobs:
            # Query the database for the job
            try:
                job = TestJob.objects.get(id=job_id)
            except TestJob.DoesNotExist:
                self.logger.error("[%s] unknown job id", job_id)
                return

            self.logger.info("[%s] receiving logs from a new job", job_id)
            # Create the sub directories (if needed)
            mkdir(job.output_dir)
            self.jobs[job_id] = JobHandler(job)
        handler = self.jobs[job_id]

        # For 'event', send an event and log as 'debug'
        if message_lvl == 'event':
            self.logger.debug("[%s] event: %s", job_id, message_msg)
            send_event(".event", "lavaserver", {"message": message_msg, "job": job_id})
            message_lvl = "debug"
        # For 'marker', save in the database and log as 'debug'
        elif message_lvl == 'marker':
            # TODO: save on the file system in case of lava-logs restart
            # "msg" must be a dictionary: anything else is an invalid marker
            # and must not raise.
            m_type = case = None
            if isinstance(message_msg, dict):
                m_type = message_msg.get("type")
                case = message_msg.get("case")
            if m_type is None or case is None:
                self.logger.error("[%s] invalid marker: %s", job_id, message_msg)
                return
            # The marker points to the previous line, hence the "- 1"
            marker_line = handler.line_count() - 1
            handler.markers.setdefault(case, {})[m_type] = marker_line
            self.logger.debug("[%s] marker: %s line: %s", job_id, message_msg, marker_line)
            return

        # Mark the file handler as used
        handler.last_usage = time.time()
        # The format is a list of dictionaries
        handler.write("- %s" % message)

        if message_lvl != "results":
            return

        # The code below calls message_msg.get(): drop anything that is not
        # a dictionary instead of raising and stopping the daemon.
        if not isinstance(message_msg, dict):
            self.logger.error("[%s] invalid results, not a dictionary: %s",
                              job_id, message_msg)
            return

        try:
            job = TestJob.objects.get(pk=job_id)
        except TestJob.DoesNotExist:
            self.logger.error("[%s] unknown job id", job_id)
            return
        meta_filename = create_metadata_store(message_msg, job)
        new_test_case = map_scanned_results(results=message_msg, job=job,
                                            markers=handler.markers,
                                            meta_filename=meta_filename)

        if new_test_case is None:
            self.logger.warning(
                "[%s] unable to map scanned results: %s",
                job_id, message)
        else:
            self.test_cases.append(new_test_case)

        # Look for lava.job result
        if message_msg.get("definition") == "lava" and message_msg.get("case") == "job":
            # Flush cached test cases
            self.flush_test_cases()

            if message_msg.get("result") == "pass":
                health = TestJob.HEALTH_COMPLETE
                health_msg = "Complete"
            else:
                health = TestJob.HEALTH_INCOMPLETE
                health_msg = "Incomplete"
            self.logger.info("[%s] job status: %s", job_id, health_msg)

            infrastructure_error = (message_msg.get("error_type") in ["Bug",
                                                                      "Configuration",
                                                                      "Infrastructure"])
            if infrastructure_error:
                self.logger.info("[%s] Infrastructure error", job_id)

            # Update status.
            with transaction.atomic():
                # TODO: find a way to lock actual_device
                job = TestJob.objects.select_for_update() \
                                     .get(id=job_id)
                job.go_state_finished(health, infrastructure_error)
                job.save()

        # n.b. logging here would produce a log entry for every message in every job.

    def controler_socket(self):
        """
        Read one message from the controler socket and update the ping
        interval.

        The message must be made of: "master", "PONG" and an integer ping
        interval not smaller than TIMEOUT. Anything else is logged and
        ignored.
        """
        frames = self.controler.recv_multipart()
        try:
            master_id = u(frames[0])
            action = u(frames[1])
            ping_interval = int(frames[2])
        except (IndexError, ValueError):
            self.logger.error("Invalid message '%s'", frames)
            return

        if master_id != "master":
            self.logger.error("Invalid master id '%s'. Should be 'master'",
                              master_id)
            return
        if action != "PONG":
            self.logger.error("Invalid answer '%s'. Should be 'PONG'",
                              action)
            return
        if ping_interval < TIMEOUT:
            self.logger.error("invalid ping interval (%d) too small", ping_interval)
            return

        self.logger.debug("master => PONG(%d)", ping_interval)
        self.ping_interval = ping_interval
Exemple #26
0
class Command(LAVADaemonCommand):
    """
    worker_host is the hostname of the worker this field is set by the admin
    and could therefore be empty in a misconfigured instance.
    """
    logger = None
    help = "LAVA dispatcher master"
    default_logfile = "/var/log/lava-server/lava-master.log"

    def __init__(self, *args, **options):
        """Reset the runtime state and register the "lava-logs" dispatcher as offline."""
        super().__init__(*args, **options)

        # Sockets, authenticator and polling resources: unset at this point
        self.auth = self.controler = self.event_socket = None
        self.poller = self.pipe_r = self.inotify_fd = None

        # List of known dispatchers. At startup do not load this from the
        # database. This will help to know if the slave as restarted or not.
        logs_daemon = SlaveDispatcher("lava-logs", online=False)
        self.dispatchers = {"lava-logs": logs_daemon}
        # Job ids collected from the events, by kind of event
        self.events = {"canceling": set()}

    def add_arguments(self, parser):
        """Declare the "dispatcher config" and "network" options on top of the inherited ones."""
        super().add_arguments(parser)

        # (flag, keyword arguments given to add_argument), in declaration order
        # Important: ensure share/env.yaml is put into /etc/ by setup.py in packaging.
        config_options = [
            ('--env', {
                "default": "/etc/lava-server/env.yaml",
                "help": "Environment variables for the dispatcher processes. "
                        "Default: /etc/lava-server/env.yaml",
            }),
            ('--env-dut', {
                "default": "/etc/lava-server/env.dut.yaml",
                "help": "Environment variables for device under test. "
                        "Default: /etc/lava-server/env.dut.yaml",
            }),
            ('--dispatchers-config', {
                "default": "/etc/lava-server/dispatcher.d",
                "help": "Directory that might contain dispatcher specific configuration",
            }),
        ]
        network_options = [
            ('--master-socket', {
                "default": 'tcp://*:5556',
                "help": "Socket for master-slave communication. Default: tcp://*:5556",
            }),
            ('--event-url', {
                "default": "tcp://localhost:5500",
                "help": "URL of the publisher",
            }),
            ('--ipv6', {
                "default": False,
                "action": 'store_true',
                "help": "Enable IPv6 on the listening sockets",
            }),
            ('--encrypt', {
                "default": False,
                "action": 'store_true',
                "help": "Encrypt messages",
            }),
            ('--master-cert', {
                "default": '/etc/lava-dispatcher/certificates.d/master.key_secret',
                "help": "Certificate for the master socket",
            }),
            ('--slaves-certs', {
                "default": '/etc/lava-dispatcher/certificates.d',
                "help": "Directory for slaves certificates",
            }),
        ]

        config = parser.add_argument_group("dispatcher config")
        for flag, kwargs in config_options:
            config.add_argument(flag, **kwargs)

        net = parser.add_argument_group("network")
        for flag, kwargs in network_options:
            net.add_argument(flag, **kwargs)

    def send_status(self, hostname):
        """
        Send a STATUS message to the given worker for each of its running
        jobs (used when the master crashed, to get the current state of the
        jobs).
        """
        running_jobs = TestJob.objects.filter(
            actual_device__worker_host__hostname=hostname,
            state=TestJob.STATE_RUNNING)
        for running in running_jobs:
            self.logger.info("[%d] STATUS => %s (%s)", running.id, hostname,
                             running.actual_device.hostname)
            frames = [hostname, 'STATUS', str(running.id)]
            send_multipart_u(self.controler, frames)

    def dispatcher_alive(self, hostname):
        """
        Mark the dispatcher as alive.

        A dispatcher that is not known yet is registered first and asked
        for the status of its jobs.
        """
        already_known = hostname in self.dispatchers
        if not already_known:
            # The server crashed: send a STATUS message
            self.logger.warning("Unknown dispatcher <%s> (server crashed)", hostname)
            self.dispatchers[hostname] = SlaveDispatcher(hostname)
            self.send_status(hostname)

        self.dispatchers[hostname].alive()

    def controler_socket(self):
        """
        Read (without blocking) one message from the controler socket and
        dispatch it to the handler of its action.

        :return: False when no message was available, True otherwise
        """
        try:
            # We need here to use the zmq.NOBLOCK flag, otherwise we could block
            # the whole main loop where this function is called.
            msg = self.controler.recv_multipart(zmq.NOBLOCK)
        except zmq.error.Again:
            return False
        # This is way to verbose for production and should only be activated
        # by (and for) developers
        # self.logger.debug("[CC] Receiving: %s", msg)

        # First frame: the hostname (see ZMQ documentation)
        hostname = u(msg[0])
        # Second frame: the action
        action = u(msg[1])

        # Check that lava-logs only send PINGs
        if hostname == "lava-logs" and action != "PING":
            self.logger.error("%s => %s Invalid action from log daemon",
                              hostname, action)
            return True

        # Action name => handler
        handlers = {
            'HELLO': self._handle_hello,
            'HELLO_RETRY': self._handle_hello,
            'PING': self._handle_ping,
            'END': self._handle_end,
            'START_OK': self._handle_start_ok,
        }
        handler = handlers.get(action)
        if handler is None:
            self.logger.error("<%s> sent unknown action=%s, args=(%s)",
                              hostname, action, msg[1:])
        else:
            handler(hostname, action, msg)
        return True

    def read_event_socket(self):
        """
        Read (without blocking) one event and remember the jobs that are
        being canceled.

        Only the events whose topic ends with ".testjob" are looked at:
        when their "state" is "Canceling", the job id is added to
        self.events["canceling"]. Invalid events are logged and skipped.

        :return: False when no event was available, True otherwise
        """
        try:
            msg = self.event_socket.recv_multipart(zmq.NOBLOCK)
        except zmq.error.Again:
            return False

        try:
            # Five frames are expected; only the topic and the data are used
            (topic, _, _, _, data) = (u(m) for m in msg)
        except ValueError:
            self.logger.error("Invalid event: %s", msg)
            return True

        if not topic.endswith(".testjob"):
            return True

        try:
            data = simplejson.loads(data)
            if data["state"] == "Canceling":
                self.events["canceling"].add(int(data["job"]))
        except (KeyError, TypeError, ValueError):
            # ValueError: invalid JSON or job id
            # KeyError: "state" or "job" is missing
            # TypeError: the payload is not a dictionary
            self.logger.error("Invalid event data: %s", msg)
        return True

    def _handle_end(self, hostname, action, msg):  # pylint: disable=unused-argument
        try:
            job_id = int(msg[2])
            error_msg = msg[3]
            compressed_description = msg[4]
        except (IndexError, ValueError):
            self.logger.error("Invalid message from <%s> '%s'", hostname, msg)
            return

        try:
            job = TestJob.objects.get(id=job_id)
        except TestJob.DoesNotExist:
            self.logger.error("[%d] Unknown job", job_id)
            # ACK even if the job is unknown to let the dispatcher
            # forget about it
            send_multipart_u(self.controler, [hostname, 'END_OK', str(job_id)])
            return

        filename = os.path.join(job.output_dir, 'description.yaml')
        # If description.yaml already exists: a END was already received
        if os.path.exists(filename):
            self.logger.info("[%d] %s => END (duplicated), skipping", job_id, hostname)
        else:
            if compressed_description:
                self.logger.info("[%d] %s => END", job_id, hostname)
            else:
                self.logger.info("[%d] %s => END (lava-run crashed, mark job as INCOMPLETE)",
                                 job_id, hostname)
                with transaction.atomic():
                    # TODO: find a way to lock actual_device
                    job = TestJob.objects.select_for_update() \
                                         .get(id=job_id)

                    job.go_state_finished(TestJob.HEALTH_INCOMPLETE)
                    if error_msg:
                        self.logger.error("[%d] Error: %s", job_id, error_msg)
                        job.failure_comment = error_msg
                    job.save()

            # Create description.yaml even if it's empty
            # Allows to know when END messages are duplicated
            try:
                # Create the directory if it was not already created
                mkdir(os.path.dirname(filename))
                # TODO: check that compressed_description is not ""
                description = lzma.decompress(compressed_description)
                with open(filename, 'w') as f_description:
                    f_description.write(description.decode("utf-8"))
                if description:
                    parse_job_description(job)
            except (IOError, lzma.LZMAError) as exc:
                self.logger.error("[%d] Unable to dump 'description.yaml'",
                                  job_id)
                self.logger.exception("[%d] %s", job_id, exc)

        # ACK the job and mark the dispatcher as alive
        send_multipart_u(self.controler, [hostname, 'END_OK', str(job_id)])
        self.dispatcher_alive(hostname)

    def _handle_hello(self, hostname, action, msg):
        # Check the protocol version
        try:
            slave_version = int(msg[2])
        except (IndexError, ValueError):
            self.logger.error("Invalid message from <%s> '%s'", hostname, msg)
            return

        self.logger.info("%s => %s", hostname, action)
        if slave_version != PROTOCOL_VERSION:
            self.logger.error("<%s> using protocol v%d while master is using v%d",
                              hostname, slave_version, PROTOCOL_VERSION)
            return

        send_multipart_u(self.controler, [hostname, 'HELLO_OK'])
        # If the dispatcher is known and sent an HELLO, means that
        # the slave has restarted
        if hostname in self.dispatchers:
            if action == 'HELLO':
                self.logger.warning("Dispatcher <%s> has RESTARTED",
                                    hostname)
            else:
                # Assume the HELLO command was received, and the
                # action succeeded.
                self.logger.warning("Dispatcher <%s> was not confirmed",
                                    hostname)
        else:
            # No dispatcher, treat HELLO and HELLO_RETRY as a normal HELLO
            # message.
            self.logger.warning("New dispatcher <%s>", hostname)
            self.dispatchers[hostname] = SlaveDispatcher(hostname)

        # Mark the dispatcher as alive
        self.dispatcher_alive(hostname)

    def _handle_ping(self, hostname, action, msg):  # pylint: disable=unused-argument
        """Answer a PING with a PONG carrying PING_INTERVAL and mark the sender as alive."""
        self.logger.debug("%s => PING(%d)", hostname, PING_INTERVAL)
        # Send back a signal
        pong = [hostname, 'PONG', str(PING_INTERVAL)]
        send_multipart_u(self.controler, pong)
        self.dispatcher_alive(hostname)

    def _handle_start_ok(self, hostname, action, msg):  # pylint: disable=unused-argument
        try:
            job_id = int(msg[2])
        except (IndexError, ValueError):
            self.logger.error("Invalid message from <%s> '%s'", hostname, msg)
            return
        self.logger.info("[%d] %s => START_OK", job_id, hostname)
        try:
            with transaction.atomic():
                # TODO: find a way to lock actual_device
                job = TestJob.objects.select_for_update() \
                                     .get(id=job_id)
                job.go_state_running()
                job.save()
        except TestJob.DoesNotExist:
            self.logger.error("[%d] Unknown job", job_id)
        else:
            self.dispatcher_alive(hostname)

    def export_definition(self, job):  # pylint: disable=no-self-use
        """
        Return the job definition as YAML, with the "compatibility" key set
        to the pipeline compatibility of the job.
        """
        # NOTE(review): yaml.load() is called without an explicit Loader.
        # Confirm that job definitions are trusted and which PyYAML versions
        # must be supported.
        definition = yaml.load(job.definition)
        definition['compatibility'] = job.pipeline_compatibility

        # Dumping the parsed data also drops the comments, which the
        # dispatcher does not need.
        return yaml.dump(definition)

    def save_job_config(self, job, worker, device_cfg, options):
        """
        Save the configuration used to run the job in its output directory:
        job.yaml, device.yaml and, when the source files can be copied,
        env.yaml, env.dut.yaml and dispatcher.yaml.
        """
        output_dir = job.output_dir
        mkdir(output_dir)

        def in_output_dir(name):
            return os.path.join(output_dir, name)

        def copy_if_possible(source, name):
            # These files are optional: IOError while copying is ignored
            with contextlib.suppress(IOError):
                shutil.copy(source, in_output_dir(name))

        with open(in_output_dir("job.yaml"), "w") as f_out:
            f_out.write(self.export_definition(job))

        copy_if_possible(options["env"], "env.yaml")
        copy_if_possible(options["env_dut"], "env.dut.yaml")
        copy_if_possible(os.path.join(options["dispatchers_config"],
                                      "%s.yaml" % worker.hostname),
                         "dispatcher.yaml")

        with open(in_output_dir("device.yaml"), "w") as f_out:
            yaml.dump(device_cfg, f_out)

    def start_job(self, job, options):
        """
        Save the job configuration and send the START message to the worker
        of the job.

        For multinode jobs, a START message is also sent for every sub job
        that is a dynamic connection, using a minimised device
        configuration.
        """
        # The job context holds the variables used when loading the device
        # configuration
        job_ctx = yaml.load(job.definition).get('context', {})

        device = job.actual_device
        worker = device.worker_host

        # Load configurations
        env_str = load_optional_yaml_file(options['env'])
        env_dut_str = load_optional_yaml_file(options['env_dut'])
        device_cfg = device.load_configuration(job_ctx)
        dispatcher_cfg = load_optional_yaml_file(
            os.path.join(options['dispatchers_config'],
                         "%s.yaml" % worker.hostname))

        self.save_job_config(job, worker, device_cfg, options)
        self.logger.info("[%d] START => %s (%s)", job.id,
                         worker.hostname, device.hostname)
        start_frames = [worker.hostname, 'START', str(job.id),
                        self.export_definition(job),
                        yaml.dump(device_cfg),
                        dispatcher_cfg, env_str, env_dut_str]
        send_multipart_u(self.controler, start_frames)

        # For multinode jobs, start the dynamic connections
        for sub_job in job.sub_jobs_list:
            if sub_job == job or not sub_job.dynamic_connection:
                continue

            # inherit only enough configuration for dynamic_connection operation
            self.logger.info("[%d] Trimming dynamic connection device configuration.", sub_job.id)
            min_device_cfg = job.actual_device.minimise_configuration(device_cfg)

            self.save_job_config(sub_job, worker, min_device_cfg, options)
            self.logger.info("[%d] START => %s (connection)",
                             sub_job.id, worker.hostname)
            connection_frames = [worker.hostname, 'START',
                                 str(sub_job.id),
                                 self.export_definition(sub_job),
                                 yaml.dump(min_device_cfg), dispatcher_cfg,
                                 env_str, env_dut_str]
            send_multipart_u(self.controler, connection_frames)

    def start_jobs(self, options):
        """
        Loop on all scheduled jobs and send the START message to the slave.

        When a job cannot be started (template, file or YAML error), the
        error is saved as a failed lava.job test case and the job is
        finished as INCOMPLETE with an infrastructure error.
        """
        # make the request atomic, then only keep the jobs that are ready,
        # on online workers, and that have a device: test jobs without a
        # device are special test jobs like dynamic connection.
        # TODO: find a way to lock actual_device
        query = (TestJob.objects.select_for_update()
                 .filter(state=TestJob.STATE_SCHEDULED)
                 .filter(actual_device__worker_host__state=Worker.STATE_ONLINE)
                 .exclude(actual_device=None))

        # Loop on all jobs
        for job in query:
            try:
                self.start_job(job, options)
            except jinja2.TemplateNotFound as exc:
                self.logger.error("[%d] Template not found: '%s'",
                                  job.id, exc.message)
                error = "Template not found: '%s'" % exc.message
            except jinja2.TemplateSyntaxError as exc:
                self.logger.error("[%d] Template syntax error in '%s', line %d: %s",
                                  job.id, exc.name, exc.lineno, exc.message)
                error = "Template syntax error in '%s', line %d: %s" % (exc.name, exc.lineno, exc.message)
            except IOError as exc:
                self.logger.error("[%d] Unable to read '%s': %s",
                                  job.id, exc.filename, exc.strerror)
                error = "Cannot open '%s': %s" % (exc.filename, exc.strerror)
            except yaml.YAMLError as exc:
                self.logger.error("[%d] Unable to parse job definition: %s",
                                  job.id, exc)
                error = "Cannot parse job definition: %s" % exc
            else:
                # The job was started: nothing to record
                continue

            # Add the error as lava.job result
            metadata = {"case": "job",
                        "definition": "lava",
                        "error_type": "Infrastructure",
                        "error_msg": error,
                        "result": "fail"}
            suite, _ = TestSuite.objects.get_or_create(name="lava", job=job)
            TestCase.objects.create(name="job", suite=suite, result=TestCase.RESULT_FAIL,
                                    metadata=yaml.dump(metadata))
            job.go_state_finished(TestJob.HEALTH_INCOMPLETE, True)
            job.save()

    def cancel_jobs(self, partial=False):
        """
        Send a CANCEL message for every job in the canceling state.

        When partial is true, only the jobs whose id is listed in
        self.events["canceling"] are considered.
        """
        canceling = TestJob.objects.filter(state=TestJob.STATE_CANCELING)
        if partial:
            canceling = canceling.filter(id__in=list(self.events["canceling"]))

        for job in canceling:
            # Dynamic connections resolve their worker through lookup_worker,
            # the other jobs through their actual device.
            if job.dynamic_connection:
                target = job.lookup_worker
            else:
                target = job.actual_device.worker_host

            self.logger.info("[%d] CANCEL => %s", job.id, target.hostname)
            frames = [target.hostname, 'CANCEL', str(job.id)]
            send_multipart_u(self.controler, frames)

    def handle(self, *args, **options):
        """
        Entry point of the command: set up the master and run its main loop.

        The steps are, in order: configure logging, drop privileges, mark
        every worker offline, create the ROUTER (controler) and SUB (event)
        sockets, optionally enable IPv6 and CURVE encryption, bind/connect
        the sockets, register everything on a poller and call main_loop().

        Returns early (without running the main loop) when privileges cannot
        be dropped or when loading the certificates raises IOError.

        Keys read from options: "level", "log_file", "user", "group",
        "ipv6", "encrypt", "master_cert", "slaves_certs", "master_socket"
        and "event_url".
        """
        # Initialize logging.
        self.setup_logging("lava-master", options["level"],
                           options["log_file"], FORMAT)

        self.logger.info("[INIT] Dropping privileges")
        if not self.drop_privileges(options['user'], options['group']):
            self.logger.error("[INIT] Unable to drop privileges")
            return

        # Every worker is considered offline until it is heard from again.
        self.logger.info("[INIT] Marking all workers as offline")
        with transaction.atomic():
            for worker in Worker.objects.select_for_update().all():
                worker.go_state_offline()
                worker.save()

        # Create the sockets
        context = zmq.Context()
        self.controler = context.socket(zmq.ROUTER)
        self.event_socket = context.socket(zmq.SUB)

        if options['ipv6']:
            self.logger.info("[INIT] Enabling IPv6")
            self.controler.setsockopt(zmq.IPV6, 1)
            self.event_socket.setsockopt(zmq.IPV6, 1)

        if options['encrypt']:
            self.logger.info("[INIT] Starting encryption")
            try:
                # The authenticator validates the slaves against the
                # certificates found in options['slaves_certs'].
                self.auth = ThreadAuthenticator(context)
                self.auth.start()
                self.logger.debug("[INIT] Opening master certificate: %s", options['master_cert'])
                master_public, master_secret = zmq.auth.load_certificate(options['master_cert'])
                self.logger.debug("[INIT] Using slaves certificates from: %s", options['slaves_certs'])
                self.auth.configure_curve(domain='*', location=options['slaves_certs'])
            except IOError as err:
                self.logger.error(err)
                self.auth.stop()
                return
            # The controler socket acts as the CURVE server. These options
            # are set before the socket is bound below.
            self.controler.curve_publickey = master_public
            self.controler.curve_secretkey = master_secret
            self.controler.curve_server = True

            # Watch the certificates directory so main_loop() can reload the
            # certificates when the directory changes.
            self.logger.debug("[INIT] Watching %s", options["slaves_certs"])
            self.inotify_fd = watch_directory(options["slaves_certs"])
            if self.inotify_fd is None:
                # Not fatal: the master keeps running without the watch.
                self.logger.error("[INIT] Unable to start inotify")

        self.controler.setsockopt(zmq.IDENTITY, b"master")
        # From http://api.zeromq.org/4-2:zmq-setsockopt#toc42
        # "If two clients use the same identity when connecting to a ROUTER
        # [...] the ROUTER socket shall hand-over the connection to the new
        # client and disconnect the existing one."
        self.controler.setsockopt(zmq.ROUTER_HANDOVER, 1)
        self.controler.bind(options['master_socket'])

        # Only receive the events published under settings.EVENT_TOPIC
        self.event_socket.setsockopt(zmq.SUBSCRIBE, b(settings.EVENT_TOPIC))
        self.event_socket.connect(options['event_url'])

        # Poll on the sockets. This allow to have a
        # nice timeout along with polling.
        self.poller = zmq.Poller()
        self.poller.register(self.controler, zmq.POLLIN)
        self.poller.register(self.event_socket, zmq.POLLIN)
        # NOTE(review): self.inotify_fd is only assigned above when
        # encryption is enabled; this assumes it is initialised elsewhere
        # (e.g. in __init__) - confirm.
        if self.inotify_fd is not None:
            self.poller.register(os.fdopen(self.inotify_fd), zmq.POLLIN)

        # Translate signals into zmq messages
        (self.pipe_r, _) = self.setup_zmq_signal_handler()
        self.poller.register(self.pipe_r, zmq.POLLIN)

        self.logger.info("[INIT] LAVA master has started.")
        self.logger.info("[INIT] Using protocol version %d", PROTOCOL_VERSION)

        try:
            self.main_loop(options)
        except BaseException as exc:
            # Log whatever stopped the loop; the cleanup below always runs.
            self.logger.error("[CLOSE] Unknown exception raised, leaving!")
            self.logger.exception(exc)
        finally:
            # Drop controler socket: the protocol does handle lost messages
            self.logger.info("[CLOSE] Closing the controler socket and dropping messages")
            self.controler.close(linger=0)
            self.event_socket.close(linger=0)
            if options['encrypt']:
                self.auth.stop()
            context.term()

    def main_loop(self, options):
        """
        Run the master event loop until a signal is received on self.pipe_r.

        Each iteration polls the registered sockets, then handles (in this
        order) the signal pipe, the controler socket, the event socket, the
        inotify descriptor, the dispatchers timeout check and finally the
        scheduling/canceling of jobs.

        Database errors (OperationalError, InterfaceError) raised during an
        iteration do not stop the loop: the connection is closed so it gets
        reopened, and the loop resumes after a 2 seconds pause.

        options is forwarded to start_jobs(); options['slaves_certs'] is
        used when reloading the certificates.
        """
        last_schedule = last_dispatcher_check = time.time()

        while True:
            try:
                try:
                    # Compute the timeout: time left (in seconds) before the
                    # next scheduling or the next dispatchers check,
                    # whichever comes first.
                    now = time.time()
                    timeout = min(SCHEDULE_INTERVAL - (now - last_schedule),
                                  PING_INTERVAL - (now - last_dispatcher_check))
                    # If some actions are remaining, decrease the timeout
                    if self.events["canceling"]:
                        timeout = min(timeout, 1)
                    # Wait at least for 1ms
                    # (the value is converted from seconds to milliseconds)
                    timeout = max(timeout * 1000, 1)

                    # Wait for data or a timeout
                    sockets = dict(self.poller.poll(timeout))
                except zmq.error.ZMQError:
                    # Polling failed: start a new iteration
                    continue

                if sockets.get(self.pipe_r) == zmq.POLLIN:
                    self.logger.info("[POLL] Received a signal, leaving")
                    break

                # Command socket
                if sockets.get(self.controler) == zmq.POLLIN:
                    while self.controler_socket():  # Unqueue all pending messages
                        pass

                # Events socket
                if sockets.get(self.event_socket) == zmq.POLLIN:
                    while self.read_event_socket():  # Unqueue all pending messages
                        pass
                    # Wait for the next iteration to handle the event.
                    # In fact, the code that generated the event (lava-logs or
                    # lava-server-gunicorn) needs some time to commit the
                    # database transaction.
                    # If we are too fast, the database object won't be
                    # available (or in the right state) yet.
                    continue

                # Inotify socket
                if sockets.get(self.inotify_fd) == zmq.POLLIN:
                    # Consume the pending inotify data, then reload the
                    # certificates from the watched directory.
                    os.read(self.inotify_fd, 4096)
                    self.logger.debug("[AUTH] Reloading certificates from %s",
                                      options['slaves_certs'])
                    self.auth.configure_curve(domain='*', location=options['slaves_certs'])

                # Check dispatchers status: an online dispatcher that did
                # not send any message for more than DISPATCHER_TIMEOUT is
                # marked offline.
                now = time.time()
                if now - last_dispatcher_check > PING_INTERVAL:
                    for hostname, dispatcher in self.dispatchers.items():
                        if dispatcher.online and now - dispatcher.last_msg > DISPATCHER_TIMEOUT:
                            if hostname == "lava-logs":
                                self.logger.error("[STATE] lava-logs goes OFFLINE")
                            else:
                                self.logger.error("[STATE] Dispatcher <%s> goes OFFLINE", hostname)
                            self.dispatchers[hostname].go_offline()
                    last_dispatcher_check = now

                # Limit accesses to the database. This will also limit the rate of
                # CANCEL and START messages
                if time.time() - last_schedule > SCHEDULE_INTERVAL:
                    # Jobs are only scheduled and started while lava-logs
                    # is online.
                    if self.dispatchers["lava-logs"].online:
                        schedule(self.logger)

                        # Dispatch scheduled jobs
                        with transaction.atomic():
                            self.start_jobs(options)
                    else:
                        self.logger.warning("lava-logs is offline: can't schedule jobs")

                    # Handle canceling jobs
                    self.cancel_jobs()

                    # Do not count the time taken to schedule jobs
                    last_schedule = time.time()
                else:
                    # Cancel the jobs and remove the jobs from the set
                    if self.events["canceling"]:
                        self.cancel_jobs(partial=True)
                        self.events["canceling"] = set()

            except (OperationalError, InterfaceError):
                self.logger.info("[RESET] database connection reset.")
                # Closing the database connection will force Django to reopen
                # the connection
                connection.close()
                time.sleep(2)
Exemple #27
0
class ZmqReceiver(ZmqBase):
    '''
    A ZmqReceiver class will listen on a REP or SUB socket for messages
    and will call the 'handle_incoming_message()' method to process it.
    Subclasses should override that. A response must be implemented for
    REP sockets, but is useless for SUB sockets.
    '''
    def __init__(self,
                 zmq_rep_bind_address: Optional[str] = None,
                 zmq_sub_connect_addresses: Optional[Tuple[SubSocketAddress,
                                                           ...]] = None,
                 recreate_timeout: Optional[int] = 600,
                 username: Optional[str] = None,
                 password: Optional[str] = None):
        '''
        Create the SUB sockets, the optional authenticator and the optional
        REP socket. Nothing is polled until run() is called.

        :param zmq_rep_bind_address: address given to the REP socket; no REP
            socket is created when it is empty or None.
        :param zmq_sub_connect_addresses: one SUB socket is created per
            address; None means no SUB socket.
        :param recreate_timeout: passed to every SubSocket as
            'timeout_in_sec'.
        :param username: PLAIN user name; authentication is only enabled
            when both username and password are given.
        :param password: PLAIN password for 'username'.
        '''
        super().__init__()
        self.__context = zmq.Context()
        self.__poller = zmq.Poller()

        self.__sub_sockets = tuple(
            SubSocket(
                ctx=self.__context,
                poller=self.__poller,
                address=address,
                timeout_in_sec=recreate_timeout,
            ) for address in (zmq_sub_connect_addresses or tuple()))

        self.__auth: Optional[ThreadAuthenticator] = None
        if username is not None and password is not None:
            # Start an authenticator for this context.
            # Does not work on PUB/SUB as far as I know (probably because the
            # more secure solutions require two way communication as well)
            self.__auth = ThreadAuthenticator(self.__context)
            self.__auth.start()

            # Instruct authenticator to handle PLAIN requests
            self.__auth.configure_plain(domain='*',
                                        passwords={
                                            username: password,
                                        })

        self.__rep_socket = RepSocket(
            ctx=self.__context,
            poller=self.__poller,
            address=zmq_rep_bind_address,
            auth=self.__auth,
        ) if zmq_rep_bind_address else None

        self.__last_received_message = None
        self.__is_running = False

    @property
    def is_running(self) -> bool:
        '''True between the start of run() and the moment it is stopped.'''
        return self.__is_running

    def stop(self) -> None:
        '''
        Ask run() to stop and stop the authenticator (if any).

        run() only notices the request after its current poll returns, so it
        may take up to one poll timeout (1 second) to actually stop. The
        sockets are destroyed by run() itself when it leaves its loop.
        '''

        self._info('Closing pub and sub sockets...')
        self.__is_running = False

        if self.__auth is not None:
            self.__auth.stop()

    def _run_rep_socket(self, socks) -> None:
        '''
        Read one message from the REP socket (if there is a REP socket and a
        message), hand it to handle_incoming_message() and send the result
        back. Exceptions raised while handling or replying are logged.
        '''
        if self.__rep_socket is None:
            return

        incoming_message = self.__rep_socket.recv_string(socks)
        if incoming_message is None:
            return

        # Heartbeats are not remembered as the last received message
        if incoming_message != self.HEARTBEAT_MSG:
            self.__last_received_message = incoming_message

        self._debug('Got info from REP socket')

        try:
            response_message = self.handle_incoming_message(incoming_message)
            self.__rep_socket.send(response_message)
        except Exception as e:
            self._error(e)

    def _run_sub_sockets(self, socks) -> None:
        '''
        Read one message from every SUB socket that has one and hand it to
        handle_incoming_message(). The return value is ignored: nothing is
        sent back on a SUB socket. Exceptions raised by the handler are
        logged.
        '''
        for sub_socket in self.__sub_sockets:
            incoming_message = sub_socket.recv_string(socks)

            if incoming_message is None:
                continue

            # Heartbeats are not remembered as the last received message
            if incoming_message != self.HEARTBEAT_MSG:
                self.__last_received_message = incoming_message

            self._debug('Got info from SUB socket')

            try:
                self.handle_incoming_message(incoming_message)
            except Exception as e:
                self._error(e)

    def run(self) -> None:
        '''
        Poll the sockets (1 second timeout) and dispatch incoming messages
        until stop() is called, then destroy the sockets.

        Errors raised by the poller are logged and polling is retried.
        KeyboardInterrupt and SystemExit are not swallowed: they end the
        loop and propagate to the caller, after the sockets are destroyed.
        '''
        self.__is_running = True

        try:
            while self.__is_running:
                try:
                    socks = dict(self.__poller.poll(1000))
                except Exception as e:
                    # Only catch Exception: catching BaseException here
                    # would make the loop impossible to interrupt.
                    self._error(e)
                    continue

                self._debug('Poll cycle over. checking sockets')
                self._run_rep_socket(socks)
                self._run_sub_sockets(socks)
        finally:
            # Always release the sockets, whatever the reason for leaving
            # the loop is.
            self.__is_running = False

            if self.__rep_socket:
                self.__rep_socket.destroy()

            for sub_socket in self.__sub_sockets:
                sub_socket.destroy()

    def create_response_message(self,
                                status_code: int,
                                status_message: str,
                                response_message: Optional[str] = None) -> str:
        '''
        Build the JSON response string.

        The payload always holds the status code and status message; the
        response message is only added when it is not None.
        '''
        payload = {
            self.STATUS_CODE: status_code,
            self.STATUS_MSG: status_message,
        }

        if response_message is not None:
            payload[self.RESPONSE_MSG] = response_message

        return json.dumps(payload)

    def handle_incoming_message(self, message: str) -> Optional[str]:
        '''
        Default handler: returns None for a heartbeat and an "OK" response
        for anything else. Subclasses should override this.
        '''
        if message == self.HEARTBEAT_MSG:
            return None

        return self.create_response_message(
            status_code=self.STATUS_CODE_OK,
            status_message=self.STATUS_MSG_OK,
        )

    def get_last_received_message(self) -> Optional[str]:
        '''Return the last non-heartbeat message received, or None.'''
        return self.__last_received_message

    def get_sub_socket(self, idx: int) -> SubSocket:
        '''Return the SUB socket created for the idx-th connect address.'''
        return self.__sub_sockets[idx]
Exemple #28
0
class MultiNodeAgent(BEMOSSAgent):
    '''Listens to everything and publishes a heartbeat according to the
    heartbeat period specified in the settings module.
    '''
    def __init__(self, config_path, **kwargs):
        '''Initialise the agent identity and the multinode bookkeeping.

        config_path is accepted but not used by this constructor; the
        remaining keyword arguments are forwarded to the parent class.
        '''
        super(MultiNodeAgent, self).__init__(**kwargs)

        # Identity of this agent
        self.agent_id = 'multinodeagent'
        self.identity = self.agent_id

        # Per-node status, filled by getMultinodeData()
        self.multinode_status = {}

        # Parent tracking. The 1991 date is used as "minus infinity".
        self.is_parent = False
        self.last_sync_with_parent = datetime(1991, 1, 1)
        self.parent_node = None

        # Database connection, opened later
        self.curcon = None

        # Filled as nodes are discovered to be online/offline
        self.recently_online_node_list = []
        self.recently_offline_node_list = []

        # Settings used to log the node online/offline events
        self.offline_variables = offline_variables
        self.offline_variables['logged_by'] = self.agent_id
        self.offline_table = offline_table
        self.offline_log_variables = offline_log_variables

    def getMultinodeData(self):
        '''Load the multinode data and rebuild everything derived from it.

        Sets multinode_data, nodelist_dict, node_name_list, address_list,
        server_key_list, node_name and node_index, and creates a status
        entry for every node that does not have one yet.

        Raises ValueError when 'this_node' is not one of the known nodes.
        '''
        self.multinode_data = db_helper.get_multinode_data()
        known_nodes = self.multinode_data['known_nodes']

        self.nodelist_dict = dict((node['name'], node) for node in known_nodes)
        self.node_name_list = [node['name'] for node in known_nodes]
        self.address_list = [node['address'] for node in known_nodes]
        self.server_key_list = [node['server_key'] for node in known_nodes]
        self.node_name = self.multinode_data['this_node']

        # Position of this node among the known nodes (first match wins)
        own_positions = [
            position for position, node in enumerate(known_nodes)
            if node['name'] == self.node_name
        ]
        if not own_positions:
            raise ValueError(
                '"this_node:" entry on the multinode_data json file is invalid'
            )
        self.node_index = own_positions[0]

        for node_name in self.node_name_list:
            # Nodes may already be known when this method is called again:
            # their status is left untouched.
            if node_name in self.multinode_status:
                continue
            self.multinode_status[node_name] = {
                'health': -10,  # initialized; never online/offline
                'last_sync_time': datetime(1991, 1, 1),
                'last_online_time': None,
                'last_offline_time': None,
                'last_scanned_time': None,
            }

    def configure_authenticator(self):
        '''Configure self.auth: call allow() without any address, then enable
        CURVE for every domain using the certificates stored in
        self.public_keys_dir.
        '''
        authenticator = self.auth
        authenticator.allow()
        # Tell authenticator to use the certificate in a directory
        certificates_dir = self.public_keys_dir
        authenticator.configure_curve(domain='*', location=certificates_dir)

    @Core.receiver('onsetup')
    def onsetup(self, sender, **kwargs):
        '''Load the multinode data, start the authenticator and bind the
        CURVE-enabled PUB socket of this node (stored in self.server).

        Exits the process with status 1 when the key directories are missing.
        '''
        print("Setup")
        self.getMultinodeData()

        agent_dir = settings.PROJECT_DIR + "/Agents/MultiNodeAgent/"
        self.secret_keys_dir = os.path.abspath(
            os.path.join(agent_dir, 'private_keys'))
        self.public_keys_dir = os.path.abspath(
            os.path.join(agent_dir, 'public_keys'))

        keys_available = (os.path.exists(self.public_keys_dir)
                          and os.path.exists(self.secret_keys_dir))
        if not keys_available:
            logging.critical(
                "Certificates are missing - run generate_certificates.py script first"
            )
            sys.exit(1)

        self.ctx = zmq.Context.instance()
        # Start an authenticator for this context.
        self.auth = ThreadAuthenticator(self.ctx)
        self.auth.start()
        self.configure_authenticator()

        publisher = self.ctx.socket(zmq.PUB)

        # Entry describing this node in the multinode data
        this_node = self.multinode_data['known_nodes'][self.node_index]
        secret_key_path = os.path.join(self.secret_keys_dir,
                                       this_node['server_secret_key'])
        public_key, secret_key = zmq.auth.load_certificate(secret_key_path)
        publisher.curve_secretkey = secret_key
        publisher.curve_publickey = public_key
        publisher.curve_server = True  # must come before bind
        publisher.bind(
            self.multinode_data['known_nodes'][self.node_index]['address'])
        self.server = publisher

    def check_if_parent(self):
        '''Promote this node to parent when it is the first known node.

        When promoted: starts the web-server script, publishes a 'start'
        status change for the device discovery agent and updates the parent
        information. Does nothing for any other node.
        '''
        # The first entry is the parent; always
        if self.node_name != self.node_name_list[0]:
            return

        self.is_parent = True
        self.node_index = 0
        print("I am the boss now, " + self.node_name)

        # start the web-server
        webserver_command = (settings.PROJECT_DIR + "/start_webserver.sh " +
                             settings.PROJECT_DIR)
        subprocess.check_output(webserver_command, shell=True)

        start_discovery = {
            STATUS_CHANGE.AGENT_ID: 'devicediscoveryagent',
            STATUS_CHANGE.NODE: str(self.node_index),
            STATUS_CHANGE.AGENT_STATUS: 'start',
            STATUS_CHANGE.NODE_ASSIGNMENT_TYPE: ZONE_ASSIGNMENT_TYPES.PERMANENT,
        }
        self.bemoss_publish('status_change', 'networkagent',
                            [start_discovery])
        self.updateParent(self.node_name)
        print("discoveryagent started")

    def disperseMessage(self, topic, header, message):
        '''Send message on self.server once per known node other than this
        one, under '<node name>/republish/<topic>'. header is not used.
        '''
        other_nodes = [
            name for name in self.node_name_list if name != self.node_name
        ]
        for target in other_nodes:
            republish_topic = target + '/republish/' + topic
            self.server.send(jsonify(republish_topic, message))

    def republishToParent(self, topic, header, message):
        '''Send message on self.server for every node whose health is 2
        (the parent), under '<node name>/republish/<topic>'.

        Does nothing when this node is the parent. header is not used.
        '''
        if self.is_parent:
            # if I am parent, the message is already published
            return

        for node_name in self.node_name_list:
            # health = 2 is the parent node
            if self.multinode_status[node_name]['health'] != 2:
                continue
            republish_topic = node_name + '/republish/' + topic
            self.server.send(jsonify(republish_topic, message))

    @Core.periodic(20)
    def send_heartbeat(self):
        '''Send a heartbeat on self.server. The topic is
        'heartbeat/<node name>/<is_parent>/<last sync with parent>' and the
        message is empty.
        '''
        print("Sending heartbeat")

        sync_stamp = self.last_sync_with_parent.strftime(
            '%B %d, %Y, %H:%M:%S')
        heartbeat_topic = ('heartbeat/' + self.node_name + '/' +
                           str(self.is_parent) + '/' + sync_stamp)
        self.server.send(jsonify(heartbeat_topic, ""))

    def extract_ip(self, addr):
        '''Return the first dotted-quad (IPv4-looking) substring of addr.'''
        dotted_quad = r'([0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3})'
        found = re.search(dotted_quad, addr)
        return found.group(1)

    def getNodeId(self, node_name):
        '''Return the index of the first known node called node_name.

        Raises ValueError when no known node has that name.
        '''
        known_nodes = self.multinode_data['known_nodes']
        for position, entry in enumerate(known_nodes):
            if entry['name'] == node_name:
                return position

        raise ValueError('the node name: ' + node_name +
                         ' is not found in multinode data')

    def getNodeName(self, node_id):
        '''Return the name of the known node stored at index node_id.'''
        known_nodes = self.multinode_data['known_nodes']
        entry = known_nodes[node_id]
        return entry['name']

    def handle_offline_nodes(self, node_name_list):
        '''Log the nodes of node_name_list as offline and take over their
        agents. Does nothing unless this node is the parent.

        For every node: the offline event is inserted in the events log, a
        notification row is created, and a 'start' command with a TEMPORARY
        assignment on this node is queued for every agent assigned to the
        offline node. The queued commands (if any) are published at the end.
        '''
        if not self.is_parent:
            return

        # start all the agents belonging to that node on this node
        takeover_commands = []
        for node_name in node_name_list:
            node_id = self.getNodeId(node_name)

            # put the offline event into cassandra events log table, and also create notification
            self.offline_variables.update({
                'date_id': str(datetime.now().date()),
                'time': datetime.utcnow(),
                'agent_id': node_name,
                'event': 'node-offline',
                'reason': 'communication-error',
                'related_to': None,
                'event_id': uuid.uuid4(),
                'logged_time': datetime.utcnow(),
            })
            self.TSDCustomInsert(all_vars=self.offline_variables,
                                 log_vars=self.offline_log_variables,
                                 tablename=self.offline_table)

            triggered_at = date_converter.UTCToLocal(datetime.utcnow())
            notification_text = (
                str(node_name) +
                ': node-offline. Reason: possibly communiation-error')
            self.curcon.execute(
                "select id from possible_events where event_name=%s",
                ('node-offline', ))
            event_id = self.curcon.fetchone()[0]
            self.curcon.execute(
                "insert into notification (dt_triggered, seen, event_type_id, message) VALUES (%s, %s, %s, %s)",
                (triggered_at, False, event_id, notification_text))
            self.curcon.commit()

            # get a list of agents that were supposedly running in that offline node
            self.curcon.execute(
                "SELECT agent_id FROM " + node_devices_table +
                " WHERE assigned_node_id=%s", (node_id, ))
            if not self.curcon.rowcount:
                continue

            for row in self.curcon.fetchall():
                takeover_commands.append({
                    STATUS_CHANGE.AGENT_ID: row[0],
                    STATUS_CHANGE.NODE: str(self.node_index),
                    STATUS_CHANGE.AGENT_STATUS: 'start',
                    STATUS_CHANGE.NODE_ASSIGNMENT_TYPE:
                    ZONE_ASSIGNMENT_TYPES.TEMPORARY,
                })

        print("moving agents from offline node to parent: " +
              str(node_name_list))
        print(takeover_commands)
        if takeover_commands:
            self.bemoss_publish('status_change', 'networkagent',
                                takeover_commands)

    def handle_online_nodes(self, node_name_list):
        '''Log the nodes of node_name_list as online and give their agents
        back. Does nothing unless this node is the parent.

        For every node: the online event is inserted in the events log, a
        notification row is created, and for every agent assigned to the
        node two PERMANENT commands are queued: 'stop' on this node, then
        'start' on the node that came back. The queued commands (if any) are
        published at the end.
        '''
        if not self.is_parent:
            return

        # start all the agents belonging to that nodes back on them
        handback_commands = []
        for node_name in node_name_list:
            node_id = self.getNodeId(node_name)

            # put the online event into cassandra events log table, and also create notification
            self.offline_variables.update({
                'date_id': str(datetime.now().date()),
                'time': datetime.utcnow(),
                'agent_id': node_name,
                'event': 'node-online',
                'reason': 'communication-restored',
                'related_to': None,
                'event_id': uuid.uuid4(),
                'logged_time': datetime.utcnow(),
            })
            self.TSDCustomInsert(all_vars=self.offline_variables,
                                 log_vars=self.offline_log_variables,
                                 tablename=self.offline_table)

            triggered_at = date_converter.UTCToLocal(datetime.utcnow())
            notification_text = (
                str(node_name) +
                ': node-online. Reason: possibly communiation-restored')
            self.curcon.execute(
                "select id from possible_events where event_name=%s",
                ('node-online', ))
            event_id = self.curcon.fetchone()[0]
            self.curcon.execute(
                "insert into notification (dt_triggered, seen, event_type_id, message) VALUES (%s, %s, %s, %s)",
                (triggered_at, False, event_id, notification_text))
            self.curcon.commit()

            # get a list of agents that were supposed to be running in that online node
            self.curcon.execute(
                "SELECT agent_id FROM " + node_devices_table +
                " WHERE assigned_node_id=%s", (node_id, ))
            if not self.curcon.rowcount:
                continue

            for row in self.curcon.fetchall():
                # stop in this node
                stop_here = {
                    STATUS_CHANGE.AGENT_ID: row[0],
                    STATUS_CHANGE.NODE_ASSIGNMENT_TYPE:
                    ZONE_ASSIGNMENT_TYPES.PERMANENT,
                    STATUS_CHANGE.NODE: str(self.node_index),
                    STATUS_CHANGE.AGENT_STATUS: 'stop',
                }
                # start in the target node (independent copy of the command)
                start_there = dict(stop_here)
                start_there[STATUS_CHANGE.NODE] = str(node_id)
                start_there[STATUS_CHANGE.AGENT_STATUS] = 'start'
                handback_commands.extend([stop_here, start_there])

        print("Moving agents back to the online node: " +
              str(node_name_list))
        print(handback_commands)

        if handback_commands:
            self.bemoss_publish('status_change', 'networkagent',
                                handback_commands)

    def updateParent(self, parent_node_name):
        '''Record the IP of the parent node in the parent IP file.

        Nothing happens when the file already holds that IP. Otherwise the
        file is (re)written, the database connection is replaced by a new
        one and 'from/multinodeagent/update_parent' is published.
        '''
        parent_address = self.nodelist_dict[parent_node_name]['address']
        parent_ip = self.extract_ip(parent_address)

        if os.path.isfile(settings.MULTINODE_PARENT_IP_FILE):
            with open(settings.MULTINODE_PARENT_IP_FILE, 'r') as ip_file:
                recorded_ip = ip_file.read()
            is_outdated = recorded_ip != parent_ip
        else:
            # the parent file doesn't exist yet
            is_outdated = True

        if not is_outdated:
            return

        with open(settings.MULTINODE_PARENT_IP_FILE, 'w') as ip_file:
            ip_file.write(parent_ip)

        if self.curcon:
            # close old connection
            self.curcon.close()
        # start new connection using new parent_ip
        self.curcon = db_connection()
        self.vip.pubsub.publish('pubsub',
                                'from/multinodeagent/update_parent')

    @Core.periodic(60)
    def check_health(self):
        """Re-evaluate the health of every known node and act on changes.

        Scheduled through ``Core.periodic(60)``. The check has two phases:
        every node with a positive health is reset to 0, the method sleeps
        for 30 seconds, and every node that is still at 0 afterwards is
        marked offline.

        Health values handled here: 2 = parent node, > 0 = online,
        0 = waiting to be confirmed, -1 = offline, -10 = initialised but
        never seen online.

        When this node is the parent, the method also writes one row per
        known node into the node_info table, deletes rows whose node_id no
        longer matches a known node, and passes the recently online/offline
        node lists to their handlers.

        Raises:
            AssertionError: if no node has a positive health after the scan.
        """

        # Phase 1: reset every node that currently looks online to 0. If a
        # node is really online it should change the value back to 1 or
        # 2 (parent) within 30 seconds through the heartbeat.
        for node_name, node in self.multinode_status.items():
            if node['health'] > 0:
                node['health'] = 0

        gevent.sleep(30)
        parent_node_name = None  # name of the node found with health 2, if any
        online_node_exists = False
        # Phase 2: classify every node by the value it holds after the wait.
        for node_name, node in self.multinode_status.items():
            node['last_scanned_time'] = datetime.now()
            if node['health'] == 0:
                # Still 0 after the wait: record it as newly offline.
                node['health'] = -1
                node['last_offline_time'] = datetime.now()
                self.recently_offline_node_list += [node_name]
            elif node['health'] == -1:  # offline since an earlier scan
                pass
            elif node[
                    'health'] == -10:  # The node was initialized to -10 and never came online. Treat it as recently
                # going offline for this iteration so that the agents that were supposed to be running there can be
                # migrated. Note that 'last_offline_time' is not stamped in this branch.
                node['health'] = -1
                self.recently_offline_node_list += [node_name]
            elif node['health'] == 2:  # a parent node is present
                parent_node_name = node_name
            if node['health'] > 0:
                online_node_exists = True  # at least one node (this one) should be online; otherwise something is wrong

        assert online_node_exists, "At least one node (current node) must be online"
        if parent_node_name:  # a parent node exists
            self.updateParent(parent_node_name)

        # Print the health of every configured node.
        for node in self.multinode_data['known_nodes']:
            print node['name'] + ': ' + str(
                self.multinode_status[node['name']]['health'])

        if self.is_parent:
            # Only the parent node updates the node_info table.
            if self.curcon is None:  # no database connection yet: create one
                self.curcon = db_connection()

            tbl_node_info = settings.DATABASES['default']['TABLE_node_info']
            # Start with every node_id currently stored; ids that match a known
            # node are removed in the loop, so only stale ids remain.
            self.curcon.execute('select node_id from ' + tbl_node_info)
            to_be_deleted_node_ids = self.curcon.fetchall()
            for index, node in enumerate(self.multinode_data['known_nodes']):
                if (index, ) in to_be_deleted_node_ids:
                    to_be_deleted_node_ids.remove(
                        (index, ))  # this node is still known: keep its row
                # The rowcount left by this select decides between insert and
                # update below; `result` itself is not used.
                result = self.curcon.execute(
                    'select * from ' + tbl_node_info + ' where node_id=%s',
                    (index, ))
                node_type = 'parent' if self.multinode_status[
                    node['name']]['health'] == 2 else "child"
                node_status = "ONLINE" if self.multinode_status[
                    node['name']]['health'] > 0 else "OFFLINE"
                ip_address = self.extract_ip(node['address'])
                # NOTE(review): this reads 'last_online_time', not the
                # 'last_scanned_time' stamped earlier in this method - confirm
                # which value the last_scanned_time column should hold.
                last_scanned_time = self.multinode_status[
                    node['name']]['last_online_time']
                last_offline_time = self.multinode_status[
                    node['name']]['last_offline_time']
                last_sync_time = self.multinode_status[
                    node['name']]['last_sync_time']

                var_list = "(node_id,node_name,node_type,node_status,ip_address,last_scanned_time,last_offline_time,last_sync_time)"
                value_placeholder_list = "(%s,%s,%s,%s,%s,%s,%s,%s)"
                actual_values_list = (index, node['name'], node_type,
                                      node_status, ip_address,
                                      last_scanned_time, last_offline_time,
                                      last_sync_time)

                # NOTE(review): presumably rowcount still refers to the
                # per-node select above, i.e. extract_ip() does not use the
                # cursor - confirm.
                if self.curcon.rowcount == 0:
                    self.curcon.execute(
                        "insert into " + tbl_node_info + " " + var_list +
                        " VALUES" + value_placeholder_list, actual_values_list)
                else:
                    self.curcon.execute(
                        "update " + tbl_node_info + " SET " + var_list +
                        " = " + value_placeholder_list + " where node_id = %s",
                        actual_values_list + (index, ))
            self.curcon.commit()

            # Each remaining entry is a fetched row tuple and is passed
            # directly as the query parameters.
            for id in to_be_deleted_node_ids:
                self.curcon.execute(
                    'delete from accounts_userprofile_nodes where nodeinfo_id=%s',
                    id)  # delete entries in user-profile for the old node
                self.curcon.commit()
                self.curcon.execute('delete from ' + tbl_node_info +
                                    ' where node_id=%s',
                                    id)  # delete the old node itself
                self.curcon.commit()

            if self.recently_online_node_list:  # Online nodes should be handled first because the same node can be
                # on both recently_online_node_list and recently_offline_node_list, when it goes offline shortly after
                # coming online
                self.handle_online_nodes(self.recently_online_node_list)
                self.recently_online_node_list = []  # reset after handling
            if self.recently_offline_node_list:
                self.handle_offline_nodes(self.recently_offline_node_list)
                self.recently_offline_node_list = []  # reset after handling

    def connect_client(self, node):
        """Connect the SUB socket ``self.client`` to *node*.

        Loads the certificate named by ``node['server_key']`` from
        ``self.public_keys_dir``, sets it as the CURVE server key,
        subscribes to the 'heartbeat/' prefix and to this node's own name,
        and connects to ``node['address']``.
        """
        key_path = os.path.join(self.public_keys_dir, node['server_key'])
        # Only the first element of the loaded certificate is needed here.
        public_key, _ = zmq.auth.load_certificate(key_path)

        sub_socket = self.client
        # The client must know the server's public key to make a CURVE
        # connection.
        sub_socket.curve_serverkey = public_key
        for prefix in ('heartbeat/', self.node_name):
            sub_socket.setsockopt(zmq.SUBSCRIBE, prefix)
        sub_socket.connect(node['address'])

    def disconnect_client(self, node):
        """Disconnect ``self.client`` from the address stored in *node*."""
        address = node['address']
        self.client.disconnect(address)

    @Core.receiver('onstart')
    def onstart(self, sender, **kwargs):
        """Start the heartbeat and listen for messages from the known nodes.

        Registered as the receiver of the 'onstart' event. The method
        creates a SUB socket with this node's client certificate, connects
        it to every node in ``self.multinode_data['known_nodes']`` and then
        reads messages in a loop.

        Two kinds of topics are handled:

        * ``heartbeat/<node_name>/<is_parent>/<last_sync_with_parent>``:
          sets the node's health to 1 (not parent) or 2 (parent), records
          'last_online_time', and adds the node to
          ``recently_online_node_list`` when its health was negative.
        * ``<this node's name>/republish/<rest>``: publishes the message on
          the local pubsub under ``<rest>/repub-by-<node>/republished``.

        Any exception raised inside the loop is printed and ends the loop;
        the method then calls ``self.auth.stop()`` and returns.
        """

        self.check_if_parent()
        print "Starting to receive Heart-beat"
        self.vip.heartbeat.start_with_period(15)
        client = self.ctx.socket(zmq.SUB)
        # We need two certificates, one for the client and one for
        # the server. The client must know the server's public key
        # to make a CURVE connection.

        # The certificate file name comes from this node's own entry in the
        # known_nodes list.
        client_secret_key_filename = self.multinode_data['known_nodes'][
            self.node_index]['client_secret_key']
        client_secret_file = os.path.join(self.secret_keys_dir,
                                          client_secret_key_filename)
        client_public, client_secret = zmq.auth.load_certificate(
            client_secret_file)
        client.curve_secretkey = client_secret
        client.curve_publickey = client_public

        # Stored on the instance because connect_client/disconnect_client
        # operate on self.client.
        self.client = client

        for node in self.multinode_data['known_nodes']:
            self.connect_client(node)

        print "Starting to listen"
        try:
            while True:  # read messages
                if client.poll(1000):
                    topic, msg = dejsonify(client.recv())
                    topic_list = topic.split('/')
                    if topic_list[0] == 'heartbeat':
                        node_name = topic_list[1]
                        is_parent = topic_list[2]
                        last_sync_with_parent = topic_list[3]  # parsed but not used below
                        if self.multinode_status[node_name][
                                'health'] < 0:  # health < 0 means the node was offline
                            print node_name + " is back online"
                            self.recently_online_node_list += [node_name]

                        if is_parent.lower() in ['false', '0']:
                            self.multinode_status[node_name]['health'] = 1
                        elif is_parent.lower() in ['true', '1']:
                            self.multinode_status[node_name]['health'] = 2
                            self.parent_node = node_name
                        else:
                            # NOTE(review): this is caught by the handler
                            # below, which ends the whole listening loop.
                            raise ValueError(
                                'Invalid is_parent string in heart-beat message'
                            )

                        self.multinode_status[node_name][
                            'last_online_time'] = datetime.now()

                    if topic_list[0] == self.node_name:
                        # message addressed to this node

                        if topic_list[1] == 'republish':
                            # Drop the '<node>/republish' prefix and append a
                            # marker so the message can be recognised as
                            # already republished.
                            new_topic = '/'.join(
                                topic_list[2:] +
                                ['repub-by-' + self.node_name, 'republished'])
                            self.vip.pubsub.publish('pubsub', new_topic, None,
                                                    msg)

                    print self.node_name + ": " + topic, str(msg)

                else:
                    # Nothing arrived within the poll timeout: sleep before
                    # polling again.
                    gevent.sleep(2)

        except Exception as er:
            print "error"
            print er

        # stop auth thread
        # NOTE(review): self.auth is not assigned in this method; presumably
        # it is created elsewhere (e.g. by configure_authenticator) - confirm.
        self.auth.stop()

    @PubSub.subscribe('pubsub', 'to/multinodeagent/')
    def updateMultinodeData(self, peer, sender, bus, topic, headers, message):
        print "Updating Multinode data"
        topic_list = topic.split('/')
        self.configure_authenticator()
        #to/multinodeagent/from/<doesn't matter>/update_multinode_data
        if topic_list[4] == 'update_multinode_data':
            old_multinode_data = self.multinode_data
            self.getMultinodeData()
            for node in self.multinode_data['known_nodes']:
                if node not in old_multinode_data['known_nodes']:
                    print "New node has been added to the cluster: " + node[
                        'name']
                    print "We will connect to it"
                    self.connect_client(node)

            for node in old_multinode_data['known_nodes']:
                if node not in self.multinode_data['known_nodes']:
                    print "Node has been removed from the cluster: " + node[
                        'name']
                    print "We will disconnect from it"
                    self.disconnect_client(node)
                    # TODO: remove it from the node_info table

        print "yay! got it"

    @PubSub.subscribe('pubsub', 'to/')
    def relayToMessage(self, peer, sender, bus, topic, headers, message):
        print topic
        topic_list = topic.split('/')
        #to/<some_agent_or_ui>/topic/from/<some_agent_or_ui>
        to_index = topic_list.index('to') + 1
        if 'from' in topic_list:
            from_index = topic_list.index('from') + 1
            from_entity = topic_list[from_index]

        to_entity = topic_list[to_index]
        last_field = topic_list[-1]
        if last_field == 'republished':  #it is already a republished message, no need to republish
            return
        if to_entity in settings.SYSTEM_AGENTS:
            self.disperseMessage(topic, headers,
                                 message)  #republish to all nodes
        elif to_entity in settings.PARENT_NODE_SYSTEM_AGENTS:
            if not self.is_parent:
                self.republishToParent(topic, headers, message)
        else:
            self.curcon.execute(
                "SELECT current_node_id FROM " + node_devices_table +
                " WHERE agent_id=%s", (to_entity, ))
            if self.curcon.rowcount:
                node_id = self.curcon.fetchone()[0]
                if node_id != self.node_index:
                    self.server.send(
                        jsonify(
                            self.getNodeName(node_id) + '/republish/' + topic,
                            message))

    @PubSub.subscribe('pubsub', 'from/')
    def relayFromMessage(self, peer, sender, bus, topic, headers, message):
        """Disperse a 'from/<entity>/...' message unless it was republished.

        A message whose last topic segment is 'republished' is dropped;
        every other message is passed to ``disperseMessage``.
        """
        #from/<some_agent_or_ui>/topic
        segments = topic.split('/')
        from_entity = segments[1]  # originating entity; not used further
        already_republished = segments[-1] == 'republished'
        if not already_republished:
            self.disperseMessage(topic, headers, message)  #republish to all nodes

    @PubSub.subscribe('pubsub', '')
    def on_match(self, peer, sender, bus, topic, headers, message):
        '''Use match_all to receive all messages and print them out.

        NOTE(review): the body visible here only unpacks legacy messages
        and prints nothing - confirm whether the rest of the method is
        missing.
        '''
        # Messages sent by 'pubsub.compat' are converted with the compat
        # helper first.
        if sender == 'pubsub.compat':
            message = compat.unpack_legacy_message(headers, message)
Exemple #29
0
class Command(LAVADaemonCommand):
    """
    worker_host is the hostname of the worker this field is set by the admin
    and could therefore be empty in a misconfigured instance.
    """
    logger = None
    help = "LAVA dispatcher master"
    default_logfile = "/var/log/lava-server/lava-master.log"

    def __init__(self, *args, **options):
        """Initialise the command with empty runtime state.

        All socket-related attributes start as None. ``dispatchers``
        starts with a single offline "lava-logs" entry and ``events`` with
        two empty sets.
        """
        super().__init__(*args, **options)

        # Authentication, sockets and polling objects
        self.auth = None
        self.controler = None
        self.event_socket = None
        self.poller = None
        self.pipe_r = None
        self.inotify_fd = None

        # Known dispatchers. This is deliberately not loaded from the
        # database at startup: it helps to know whether a slave has
        # restarted or not.
        logs_dispatcher = SlaveDispatcher("lava-logs", online=False)
        self.dispatchers = {"lava-logs": logs_dispatcher}

        # Pending work collected from events
        self.events = {"canceling": set(), "available_dt": set()}

    def add_arguments(self, parser):
        """Register the "network" option group on *parser*.

        Adds the socket, event URL, IPv6, encryption and certificate
        options after the arguments of the parent class.
        """
        super().add_arguments(parser)
        network = parser.add_argument_group("network")

        # Endpoints
        network.add_argument(
            '--master-socket', default='tcp://*:5556',
            help="Socket for master-slave communication. Default: tcp://*:5556")
        network.add_argument(
            '--event-url', default="tcp://localhost:5500",
            help="URL of the publisher")

        # Boolean switches
        network.add_argument(
            '--ipv6', default=False, action='store_true',
            help="Enable IPv6 on the listening sockets")
        network.add_argument(
            '--encrypt', default=False, action='store_true',
            help="Encrypt messages")

        # Certificates used when encryption is enabled
        network.add_argument(
            '--master-cert',
            default='/etc/lava-dispatcher/certificates.d/master.key_secret',
            help="Certificate for the master socket")
        network.add_argument(
            '--slaves-certs', default='/etc/lava-dispatcher/certificates.d',
            help="Directory for slaves certificates")

    def send_status(self, hostname):
        """
        Send a STATUS message to *hostname* for each of its running jobs.

        Used after a master crash to get the current state of the jobs
        running on devices attached to that worker.
        """
        running_jobs = TestJob.objects.filter(
            actual_device__worker_host__hostname=hostname,
            state=TestJob.STATE_RUNNING)
        for running in running_jobs:
            device_name = running.actual_device.hostname
            self.logger.info("[%d] STATUS => %s (%s)", running.id, hostname,
                             device_name)
            frames = [hostname, 'STATUS', str(running.id)]
            send_multipart_u(self.controler, frames)

    def dispatcher_alive(self, hostname):
        """Mark the dispatcher *hostname* as alive.

        An unknown hostname is first registered as a new dispatcher and
        sent STATUS messages for its running jobs.
        """
        is_known = hostname in self.dispatchers
        if not is_known:
            # The server crashed: send a STATUS message
            self.logger.warning("Unknown dispatcher <%s> (server crashed)",
                                hostname)
            self.dispatchers[hostname] = SlaveDispatcher(hostname)
            self.send_status(hostname)

        # Mark the dispatcher as alive
        dispatcher = self.dispatchers[hostname]
        dispatcher.alive()

    def controler_socket(self):
        """Read and handle one message from the controler socket.

        Returns:
            False when no message was waiting, True once a message has
            been read (whether or not its action was valid).
        """
        try:
            # zmq.NOBLOCK is required here, otherwise this call could block
            # the whole main loop from which this function is called.
            frames = self.controler.recv_multipart(zmq.NOBLOCK)
        except zmq.error.Again:
            return False
        # This is way to verbose for production and should only be activated
        # by (and for) developers
        # self.logger.debug("[CC] Receiving: %s", frames)

        # Frame 1 is the hostname (see ZMQ documentation), frame 2 the action
        hostname, action = u(frames[0]), u(frames[1])

        # lava-logs is only allowed to send PINGs
        if hostname == "lava-logs" and action != "PING":
            self.logger.error("%s => %s Invalid action from log daemon",
                              hostname, action)
            return True

        # Map each known action to its handler
        handlers = {
            'HELLO': self._handle_hello,
            'HELLO_RETRY': self._handle_hello,
            'PING': self._handle_ping,
            'END': self._handle_end,
            'START_OK': self._handle_start_ok,
        }
        handler = handlers.get(action)
        if handler is None:
            self.logger.error("<%s> sent unknown action=%s, args=(%s)",
                              hostname, action, frames[1:])
        else:
            handler(hostname, action, frames)
        return True

    def read_event_socket(self):
        """Read one event and record the work it implies.

        A test job event in state "Canceling" adds the job id to
        ``events["canceling"]``. A "Submitted" test job with a device
        type, or an idle device in Good/Unknown/Looping health, adds the
        device type to ``events["available_dt"]``.

        Returns:
            False when no event was waiting, True otherwise (including
            when the event was invalid and only logged).
        """
        try:
            msg = self.event_socket.recv_multipart(zmq.NOBLOCK)
        except zmq.error.Again:
            return False

        try:
            # Exactly five frames are expected; only the first and the last
            # are used.
            topic, _, _, _, payload = [u(frame) for frame in msg]
            data = simplejson.loads(payload)
        except ValueError:
            self.logger.error("Invalid event: %s", msg)
            return True

        if topic.endswith(".testjob"):
            state = data["state"]
            if state == "Canceling":
                self.events["canceling"].add(int(data["job"]))
            elif state == "Submitted" and "device_type" in data:
                self.events["available_dt"].add(data["device_type"])
        elif topic.endswith(".device"):
            usable_health = ("Good", "Unknown", "Looping")
            if data["state"] == "Idle" and data["health"] in usable_health:
                self.events["available_dt"].add(data["device_type"])

        return True

    def _handle_end(self, hostname, action, msg):  # pylint: disable=unused-argument
        """Handle an END message: record the job outcome and acknowledge it.

        Expected frames: ``msg[2]`` job id, ``msg[3]`` error message,
        ``msg[4]`` lzma-compressed job description (empty when lava-run
        crashed, in which case the job is marked INCOMPLETE).

        ``description.yaml`` is written in the job output directory even
        when the description is empty: its presence is how a duplicated
        END message is detected. END_OK is always sent back, except when
        the message itself is malformed.
        """
        try:
            job_id = int(msg[2])
            # Decode with the same helper used for the hostname and action
            # frames so that text, not raw bytes, is logged and stored in
            # failure_comment. A decoding error is a ValueError and is
            # reported as an invalid message.
            error_msg = u(msg[3])
            compressed_description = msg[4]
        except (IndexError, ValueError):
            self.logger.error("Invalid message from <%s> '%s'", hostname, msg)
            return

        try:
            job = TestJob.objects.get(id=job_id)
        except TestJob.DoesNotExist:
            self.logger.error("[%d] Unknown job", job_id)
            # ACK even if the job is unknown to let the dispatcher
            # forget about it
            send_multipart_u(self.controler, [hostname, 'END_OK', str(job_id)])
            return

        filename = os.path.join(job.output_dir, 'description.yaml')
        # If description.yaml already exists: a END was already received
        if os.path.exists(filename):
            self.logger.info("[%d] %s => END (duplicated), skipping", job_id,
                             hostname)
        else:
            if compressed_description:
                self.logger.info("[%d] %s => END", job_id, hostname)
            else:
                self.logger.info(
                    "[%d] %s => END (lava-run crashed, mark job as INCOMPLETE)",
                    job_id, hostname)
                with transaction.atomic():
                    # TODO: find a way to lock actual_device
                    job = TestJob.objects.select_for_update() \
                                         .get(id=job_id)

                    job.go_state_finished(TestJob.HEALTH_INCOMPLETE)
                    if error_msg:
                        self.logger.error("[%d] Error: %s", job_id, error_msg)
                        job.failure_comment = error_msg
                    job.save()

            # Create description.yaml even if it's empty
            # Allows to know when END messages are duplicated
            try:
                # Create the directory if it was not already created
                mkdir(os.path.dirname(filename))
                if compressed_description:
                    description = lzma.decompress(compressed_description)
                else:
                    # lzma.decompress() raises LZMAError on empty input,
                    # which used to prevent the (empty) file from being
                    # created for crashed jobs.
                    description = b""
                with open(filename, 'w') as f_description:
                    f_description.write(description.decode("utf-8"))
                if description:
                    parse_job_description(job)
            except (OSError, lzma.LZMAError) as exc:
                self.logger.error("[%d] Unable to dump 'description.yaml'",
                                  job_id)
                self.logger.exception("[%d] %s", job_id, exc)

        # ACK the job and mark the dispatcher as alive
        send_multipart_u(self.controler, [hostname, 'END_OK', str(job_id)])
        self.dispatcher_alive(hostname)

    def _handle_hello(self, hostname, action, msg):
        """Handle HELLO / HELLO_RETRY: check the protocol and register.

        The message is dropped when the version frame is missing or
        invalid, or when it differs from PROTOCOL_VERSION. Otherwise
        HELLO_OK is sent back, the dispatcher is registered if it was not
        known, and it is marked as alive.
        """
        # Check the protocol version
        try:
            slave_version = int(msg[2])
        except (IndexError, ValueError):
            self.logger.error("Invalid message from <%s> '%s'", hostname, msg)
            return

        self.logger.info("%s => %s", hostname, action)
        if slave_version != PROTOCOL_VERSION:
            self.logger.error(
                "<%s> using protocol v%d while master is using v%d", hostname,
                slave_version, PROTOCOL_VERSION)
            return

        send_multipart_u(self.controler, [hostname, 'HELLO_OK'])

        if hostname not in self.dispatchers:
            # No dispatcher, treat HELLO and HELLO_RETRY as a normal HELLO
            # message.
            self.logger.warning("New dispatcher <%s>", hostname)
            self.dispatchers[hostname] = SlaveDispatcher(hostname)
        elif action == 'HELLO':
            # A known dispatcher sending HELLO means that the slave has
            # restarted
            self.logger.warning("Dispatcher <%s> has RESTARTED", hostname)
        else:
            # Assume the HELLO command was received, and the
            # action succeeded.
            self.logger.warning("Dispatcher <%s> was not confirmed", hostname)

        # Mark the dispatcher as alive
        self.dispatcher_alive(hostname)

    def _handle_ping(self, hostname, action, msg):  # pylint: disable=unused-argument
        """Answer a PING with a PONG carrying PING_INTERVAL and mark the
        dispatcher as alive."""
        self.logger.debug("%s => PING(%d)", hostname, PING_INTERVAL)
        # Send back a signal
        pong = [hostname, 'PONG', str(PING_INTERVAL)]
        send_multipart_u(self.controler, pong)
        self.dispatcher_alive(hostname)

    def _handle_start_ok(self, hostname, action, msg):  # pylint: disable=unused-argument
        """Handle START_OK: move the job to the running state.

        A malformed job id or an unknown job is logged and ignored; the
        dispatcher is marked as alive only when the job was updated.
        """
        try:
            job_id = int(msg[2])
        except (IndexError, ValueError):
            self.logger.error("Invalid message from <%s> '%s'", hostname, msg)
            return

        self.logger.info("[%d] %s => START_OK", job_id, hostname)
        try:
            with transaction.atomic():
                # TODO: find a way to lock actual_device
                locked_jobs = TestJob.objects.select_for_update()
                job = locked_jobs.get(id=job_id)
                job.go_state_running()
                job.save()
        except TestJob.DoesNotExist:
            self.logger.error("[%d] Unknown job", job_id)
            return

        self.dispatcher_alive(hostname)

    def export_definition(self, job):  # pylint: disable=no-self-use
        """Return the job definition as YAML with the 'compatibility' key
        set to ``job.pipeline_compatibility``."""
        definition = yaml.safe_load(job.definition)
        definition['compatibility'] = job.pipeline_compatibility

        # Re-dumping drops the comments, which the dispatcher does not need
        exported = yaml.dump(definition)
        return exported

    def save_job_config(self, job, device_cfg, env_str, env_dut_str,
                        dispatcher_cfg):
        """Write the job's configuration files into ``job.output_dir``.

        job.yaml and device.yaml are always written. env.yaml,
        env.dut.yaml and dispatcher.yaml are written only when the
        corresponding argument is not empty.
        """
        output_dir = job.output_dir
        mkdir(output_dir)

        with open(os.path.join(output_dir, "job.yaml"), "w") as f_out:
            f_out.write(self.export_definition(job))
        with open(os.path.join(output_dir, "device.yaml"), "w") as f_out:
            yaml.dump(device_cfg, f_out)

        # Optional files, written in this order
        optional_files = (
            ("env.yaml", env_str),
            ("env.dut.yaml", env_dut_str),
            ("dispatcher.yaml", dispatcher_cfg),
        )
        for basename, content in optional_files:
            if not content:
                continue
            with open(os.path.join(output_dir, basename), "w") as f_out:
                f_out.write(content)

    def start_job(self, job):
        """Save the configuration of *job* and send START to its worker.

        For multinode jobs, every other sub job that is a dynamic
        connection also gets its configuration saved (with a minimised
        device configuration) and its own START message on the same
        worker.
        """
        # Load job definition to get the variables for template
        # rendering
        job_ctx = yaml.safe_load(job.definition).get('context', {})

        device = job.actual_device
        worker = device.worker_host

        # TODO: check that device_cfg is not None!
        device_cfg = device.load_configuration(job_ctx)

        # Try to load the dispatcher specific files and then fallback to the
        # default configuration files.
        worker_dir = os.path.join(DISPATCHERS_PATH, worker.hostname)
        env_str = load_optional_yaml_file(
            os.path.join(worker_dir, "env.yaml"), ENV_PATH)
        env_dut_str = load_optional_yaml_file(
            os.path.join(worker_dir, "env.dut.yaml"), ENV_DUT_PATH)
        dispatcher_cfg = load_optional_yaml_file(
            os.path.join(worker_dir, "dispatcher.yaml"),
            os.path.join(DISPATCHERS_PATH, "%s.yaml" % worker.hostname))

        self.save_job_config(job, device_cfg, env_str, env_dut_str,
                             dispatcher_cfg)
        self.logger.info("[%d] START => %s (%s)", job.id, worker.hostname,
                         device.hostname)
        start_frames = [
            worker.hostname, 'START',
            str(job.id),
            self.export_definition(job),
            yaml.dump(device_cfg), dispatcher_cfg, env_str, env_dut_str
        ]
        send_multipart_u(self.controler, start_frames)

        # For multinode jobs, start the dynamic connections
        for sub_job in job.sub_jobs_list:
            if sub_job == job or not sub_job.dynamic_connection:
                continue

            # inherit only enough configuration for dynamic_connection operation
            self.logger.info(
                "[%d] Trimming dynamic connection device configuration.",
                sub_job.id)
            min_device_cfg = job.actual_device.minimise_configuration(
                device_cfg)

            self.save_job_config(sub_job, min_device_cfg, env_str, env_dut_str,
                                 dispatcher_cfg)
            self.logger.info("[%d] START => %s (connection)", sub_job.id,
                             worker.hostname)
            connection_frames = [
                worker.hostname, 'START',
                str(sub_job.id),
                self.export_definition(sub_job),
                yaml.dump(min_device_cfg), dispatcher_cfg, env_str, env_dut_str
            ]
            send_multipart_u(self.controler, connection_frames)

    def start_jobs(self, jobs=None):
        """
        Loop on all scheduled jobs and send the START message to the slave.

        Only scheduled jobs with a device attached to an online worker are
        considered; *jobs*, when given, restricts the selection to those
        job ids. A job that cannot be started because of a template,
        file or YAML error gets a failing "lava.job" test case and is
        finished as INCOMPLETE.
        """
        # make the request atomic, then narrow it down:
        #  - only test jobs that are ready
        #  - only jobs on online workers
        #  - no test job without a device: they are special test jobs like
        #    dynamic connection.
        ready_jobs = TestJob.objects.select_for_update() \
            .filter(state=TestJob.STATE_SCHEDULED) \
            .filter(actual_device__worker_host__state=Worker.STATE_ONLINE) \
            .exclude(actual_device=None)
        # Allow for partial scheduling
        if jobs is not None:
            ready_jobs = ready_jobs.filter(id__in=jobs)

        # Loop on all jobs
        for job in ready_jobs:
            failure = None
            try:
                self.start_job(job)
            except jinja2.TemplateNotFound as exc:
                self.logger.error("[%d] Template not found: '%s'", job.id,
                                  exc.message)
                failure = "Template not found: '%s'" % exc.message
            except jinja2.TemplateSyntaxError as exc:
                self.logger.error(
                    "[%d] Template syntax error in '%s', line %d: %s", job.id,
                    exc.name, exc.lineno, exc.message)
                failure = "Template syntax error in '%s', line %d: %s" % (
                    exc.name, exc.lineno, exc.message)
            except OSError as exc:
                self.logger.error("[%d] Unable to read '%s': %s", job.id,
                                  exc.filename, exc.strerror)
                failure = "Cannot open '%s': %s" % (exc.filename, exc.strerror)
            except yaml.YAMLError as exc:
                self.logger.error("[%d] Unable to parse job definition: %s",
                                  job.id, exc)
                failure = "Cannot parse job definition: %s" % exc

            if not failure:
                continue

            # Add the error as lava.job result
            metadata = {
                "case": "job",
                "definition": "lava",
                "error_type": "Infrastructure",
                "error_msg": failure,
                "result": "fail"
            }
            suite, _ = TestSuite.objects.get_or_create(name="lava", job=job)
            TestCase.objects.create(name="job",
                                    suite=suite,
                                    result=TestCase.RESULT_FAIL,
                                    metadata=yaml.dump(metadata))
            job.go_state_finished(TestJob.HEALTH_INCOMPLETE, True)
            job.save()

    def cancel_jobs(self, partial=False):
        """Send CANCEL for every canceling job whose worker is online.

        With ``partial=True`` only the job ids collected in
        ``events["canceling"]`` are considered.
        """
        # make the request atomic; only select test jobs that are canceling
        # and that run on online workers
        canceling = TestJob.objects.select_for_update() \
            .filter(state=TestJob.STATE_CANCELING) \
            .filter(actual_device__worker_host__state=Worker.STATE_ONLINE)

        # Allow for partial canceling
        if partial:
            canceling = canceling.filter(
                id__in=list(self.events["canceling"]))

        # Loop on all jobs
        for job in canceling:
            if job.dynamic_connection:
                worker = job.lookup_worker
            else:
                worker = job.actual_device.worker_host
            self.logger.info("[%d] CANCEL => %s", job.id, worker.hostname)
            cancel_frames = [worker.hostname, 'CANCEL', str(job.id)]
            send_multipart_u(self.controler, cancel_frames)

    def handle(self, *args, **options):
        """Entry point of the command: set everything up and run the loop.

        In order: configures logging, drops privileges (returning early on
        failure), dumps the options to 'lava-master-config.yaml' under
        MEDIA_ROOT, marks every worker offline, creates the ROUTER
        (controler) and SUB (event) sockets, optionally enables IPv6 and
        CURVE encryption, binds/connects the sockets, registers them with
        a poller and finally runs ``main_loop``. The sockets, the
        authenticator (when encryption is on) and the context are released
        when the loop ends.
        """
        # Initialize logging.
        self.setup_logging("lava-master", options["level"],
                           options["log_file"], FORMAT)

        self.logger.info("[INIT] Dropping privileges")
        if not self.drop_privileges(options['user'], options['group']):
            self.logger.error("[INIT] Unable to drop privileges")
            return

        filename = os.path.join(settings.MEDIA_ROOT, 'lava-master-config.yaml')
        self.logger.debug("[INIT] Dumping config to %s", filename)
        with open(filename, 'w') as output:
            yaml.dump(options, output)

        self.logger.info("[INIT] Marking all workers as offline")
        with transaction.atomic():
            for worker in Worker.objects.select_for_update().all():
                worker.go_state_offline()
                worker.save()

        # Create the sockets
        context = zmq.Context()
        self.controler = context.socket(zmq.ROUTER)
        self.event_socket = context.socket(zmq.SUB)

        if options['ipv6']:
            self.logger.info("[INIT] Enabling IPv6")
            self.controler.setsockopt(zmq.IPV6, 1)
            self.event_socket.setsockopt(zmq.IPV6, 1)

        if options['encrypt']:
            self.logger.info("[INIT] Starting encryption")
            try:
                self.auth = ThreadAuthenticator(context)
                self.auth.start()
                self.logger.debug("[INIT] Opening master certificate: %s",
                                  options['master_cert'])
                master_public, master_secret = zmq.auth.load_certificate(
                    options['master_cert'])
                self.logger.debug("[INIT] Using slaves certificates from: %s",
                                  options['slaves_certs'])
                self.auth.configure_curve(domain='*',
                                          location=options['slaves_certs'])
            except OSError as err:
                # NOTE(review): this early return stops the authenticator
                # but does not close the two sockets or terminate the
                # context created above - confirm this is acceptable.
                self.logger.error(err)
                self.auth.stop()
                return
            # The CURVE keys are set before the socket is bound below.
            self.controler.curve_publickey = master_public
            self.controler.curve_secretkey = master_secret
            self.controler.curve_server = True

            # Watch the certificates directory; a failure is only logged
            # and the descriptor is then left out of the poller below.
            self.logger.debug("[INIT] Watching %s", options["slaves_certs"])
            self.inotify_fd = watch_directory(options["slaves_certs"])
            if self.inotify_fd is None:
                self.logger.error("[INIT] Unable to start inotify")

        self.controler.setsockopt(zmq.IDENTITY, b"master")
        # From http://api.zeromq.org/4-2:zmq-setsockopt#toc42
        # "If two clients use the same identity when connecting to a ROUTER
        # [...] the ROUTER socket shall hand-over the connection to the new
        # client and disconnect the existing one."
        self.controler.setsockopt(zmq.ROUTER_HANDOVER, 1)
        self.controler.bind(options['master_socket'])

        self.event_socket.setsockopt(zmq.SUBSCRIBE, b(settings.EVENT_TOPIC))
        self.event_socket.connect(options['event_url'])

        # Poll on the sockets. This allow to have a
        # nice timeout along with polling.
        self.poller = zmq.Poller()
        self.poller.register(self.controler, zmq.POLLIN)
        self.poller.register(self.event_socket, zmq.POLLIN)
        if self.inotify_fd is not None:
            self.poller.register(os.fdopen(self.inotify_fd), zmq.POLLIN)

        # Translate signals into zmq messages
        (self.pipe_r, _) = self.setup_zmq_signal_handler()
        self.poller.register(self.pipe_r, zmq.POLLIN)

        self.logger.info("[INIT] LAVA master has started.")
        self.logger.info("[INIT] Using protocol version %d", PROTOCOL_VERSION)

        try:
            self.main_loop(options)
        except BaseException as exc:
            # BaseException also covers KeyboardInterrupt and SystemExit;
            # everything is logged and the cleanup below still runs.
            self.logger.error("[CLOSE] Unknown exception raised, leaving!")
            self.logger.exception(exc)
        finally:
            # Drop controler socket: the protocol does handle lost messages
            self.logger.info(
                "[CLOSE] Closing the controler socket and dropping messages")
            self.controler.close(linger=0)
            self.event_socket.close(linger=0)
            if options['encrypt']:
                self.auth.stop()
            context.term()

    def main_loop(self, options):
        """Run the master event loop until a signal arrives on ``self.pipe_r``.

        Each iteration polls the registered sockets and then, in order:
        drains the controler socket, drains the event socket, reloads the
        slave certificates when the inotify descriptor is readable, marks
        silent dispatchers offline and finally schedules, starts and cancels
        jobs.

        Args:
            options: command options mapping; only ``options['slaves_certs']``
                is read here.

        ``OperationalError`` and ``InterfaceError`` raised during an iteration
        are caught: the database connection is closed, the loop sleeps for two
        seconds and then carries on.
        """
        last_schedule = last_dispatcher_check = time.time()

        while True:
            try:
                try:
                    # Compute the timeout (in seconds): time left before the
                    # next scheduling round or dispatcher check, whichever
                    # comes first.
                    now = time.time()
                    timeout = min(
                        SCHEDULE_INTERVAL - (now - last_schedule),
                        PING_INTERVAL - (now - last_dispatcher_check))
                    # If some actions are remaining, decrease the timeout
                    if any([self.events[k] for k in self.events.keys()]):
                        timeout = min(timeout, 2)
                    # Convert to milliseconds and wait at least for 1ms (the
                    # computed value can be zero or negative when a deadline
                    # has already passed).
                    timeout = max(timeout * 1000, 1)

                    # Wait for data or a timeout
                    sockets = dict(self.poller.poll(timeout))
                except zmq.error.ZMQError:
                    # Polling failed: start a new iteration.
                    continue

                # Signal pipe: the only way out of the loop.
                if sockets.get(self.pipe_r) == zmq.POLLIN:
                    self.logger.info("[POLL] Received a signal, leaving")
                    break

                # Command socket
                if sockets.get(self.controler) == zmq.POLLIN:
                    while self.controler_socket(
                    ):  # Unqueue all pending messages
                        pass

                # Events socket
                if sockets.get(self.event_socket) == zmq.POLLIN:
                    while self.read_event_socket(
                    ):  # Unqueue all pending messages
                        pass
                    # Wait for the next iteration to handle the event.
                    # In fact, the code that generated the event (lava-logs or
                    # lava-server-gunicorn) needs some time to commit the
                    # database transaction.
                    # If we are too fast, the database object won't be
                    # available (or in the right state) yet.
                    continue

                # Inotify socket: the certificates directory changed.
                if sockets.get(self.inotify_fd) == zmq.POLLIN:
                    # Consume the pending inotify data; its content is not
                    # inspected, the whole directory is reloaded.
                    os.read(self.inotify_fd, 4096)
                    self.logger.debug("[AUTH] Reloading certificates from %s",
                                      options['slaves_certs'])
                    self.auth.configure_curve(domain='*',
                                              location=options['slaves_certs'])

                # Check dispatchers status: an online dispatcher that has been
                # silent for more than DISPATCHER_TIMEOUT goes offline.
                now = time.time()
                if now - last_dispatcher_check > PING_INTERVAL:
                    for hostname, dispatcher in self.dispatchers.items():
                        if dispatcher.online and now - dispatcher.last_msg > DISPATCHER_TIMEOUT:
                            if hostname == "lava-logs":
                                self.logger.error(
                                    "[STATE] lava-logs goes OFFLINE")
                            else:
                                self.logger.error(
                                    "[STATE] Dispatcher <%s> goes OFFLINE",
                                    hostname)
                            self.dispatchers[hostname].go_offline()
                    last_dispatcher_check = now

                # Limit accesses to the database. This will also limit the rate of
                # CANCEL and START messages
                if time.time() - last_schedule > SCHEDULE_INTERVAL:
                    # Full round: scheduling only happens while lava-logs is
                    # online, cancellation always happens.
                    if self.dispatchers["lava-logs"].online:
                        schedule(self.logger)

                        # Dispatch scheduled jobs
                        with transaction.atomic():
                            self.start_jobs()
                    else:
                        self.logger.warning(
                            "lava-logs is offline: can't schedule jobs")

                    # Handle canceling jobs
                    with transaction.atomic():
                        self.cancel_jobs()

                    # Do not count the time taken to schedule jobs
                    last_schedule = time.time()
                else:
                    # Between two full rounds, only react to recorded events.
                    # Cancel the jobs and remove the jobs from the set
                    if self.events["canceling"]:
                        with transaction.atomic():
                            self.cancel_jobs(partial=True)
                        self.events["canceling"] = set()
                    # Schedule for available device-types
                    if self.events["available_dt"]:
                        jobs = schedule(self.logger,
                                        self.events["available_dt"])
                        self.events["available_dt"] = set()
                        # Dispatch scheduled jobs
                        with transaction.atomic():
                            self.start_jobs(jobs)

            except (OperationalError, InterfaceError):
                self.logger.info("[RESET] database connection reset.")
                # Closing the database connection will force Django to reopen
                # the connection
                connection.close()
                time.sleep(2)
Exemple #30
0
class StupidNode:
    pubkey = privkey = None
    channel = ""  # subscription filter or something (I think)
    PORTS = 4  # as we add or remove ports, make sure this is the number of ports a StupidNode uses

    def __init__(self, endpoint="*", identity=None, keyring=DEFAULT_KEYRING):
        """Create the node: contexts, auth thread, sockets, poller and WAI thread.

        Args:
            endpoint: an ``Endpoint`` or a value passed to ``Endpoint(...)``.
            identity: node name; defaults to ``"<hostname>-<endpoint.pub>"``.
            keyring: directory holding the CURVE certificates.

        Side effects: binds the PUB, ROUTER and (cleartext) REP sockets,
        installs a SIGINT handler and starts the "who are you" reply thread.
        """
        self.keyring = keyring
        self.endpoint = (endpoint if isinstance(endpoint, Endpoint) else
                         Endpoint(endpoint))
        # Remote endpoints, filled by connect_to_endpoint(); index i matches
        # self.sub[i] and self.dealer[i].
        self.endpoints = list()
        self.identity = identity or f"{gethostname()}-{self.endpoint.pub}"
        self.log = logging.getLogger(f"{self.identity}")

        self.log.debug("begin node setup / creating context")

        # Two separate contexts: the auth thread is attached to self.ctx only
        # (see start_auth); sockets created with enable_curve=False live in
        # cleartext_ctx.
        self.ctx = zmq.Context()
        self.cleartext_ctx = zmq.Context()

        # Must run before mk_socket(): it loads self.pubkey / self.privkey,
        # which mk_socket() installs on every CURVE socket.
        self.start_auth()

        self.log.debug("creating sockets")

        self.pub = self.mk_socket(zmq.PUB)
        self.router = self.mk_socket(zmq.ROUTER)
        self.router.router_mandatory = (
            1  # one of the few opts that can be set after bind()
        )
        # Cleartext REP socket served by who_are_you_reply_machine().
        self.rep = self.mk_socket(zmq.REP, enable_curve=False)

        self.sub = list()
        self.dealer = list()

        self.log.debug("binding sockets")

        self.bind(self.pub)
        self.bind(self.router)
        self.bind(self.rep, enable_curve=False)

        self.log.debug("registering polling")

        # Only the router is polled at first; SUB and DEALER sockets are
        # registered as endpoints get connected.
        self.poller = zmq.Poller()
        self.poller.register(self.router, zmq.POLLIN)

        self.log.debug("configuring interrupt signal")
        # NOTE: signal.signal() may only be called from the main thread.
        signal.signal(signal.SIGINT, self.interrupt)

        self.log.debug("configuring WAI Reply Thread")
        self._who_are_you_thread = Thread(
            target=self.who_are_you_reply_machine)
        # Loop flag read by the thread; closekill() clears it before joining.
        self._who_are_you_continue = True
        self._who_are_you_thread.start()

        # Bounded queue of routed messages awaiting redelivery (see
        # route_failed); a full deque drops its oldest entry on append.
        self.route_queue = deque(list(), ROUTE_QUEUE_LEN)
        self.routes = dict()

        self.log.debug("node setup complete")

    def who_are_you_reply_machine(self):
        """Serve cleartext "who are you" requests on the REP socket.

        Runs until ``self._who_are_you_continue`` becomes false, checking the
        flag at least every 200 ms. Every request, whatever its content, is
        answered with ``[identity, pubkey]``.
        """
        while self._who_are_you_continue:
            if not self.rep.poll(200):
                # Nothing pending: loop to re-check the stop flag.
                continue
            self.log.debug("wai polled, trying to recv")
            request = self.rep.recv()
            sock_kind = zmq_socket_type_name(self.rep)
            self.log.debug('received "%s" over %s socket', request, sock_kind)
            reply = [self.identity.encode(), self.pubkey]
            self.log.debug('sending "%s" as reply over %s socket', reply,
                           sock_kind)
            self.rep.send_multipart(reply)
        self.log.debug("wai thread seems finished, loop broken")

    def start_auth(self):
        """Start the authenticator thread on the crypto context and load keys.

        The authenticator allows 127.0.0.1 and takes its CURVE certificates
        for every domain from ``self.keyring``. The node's own key pair is
        then loaded, or created if missing.
        """
        self.log.debug("starting auth thread")
        authenticator = self.auth = ThreadAuthenticator(self.ctx)
        authenticator.start()
        authenticator.allow("127.0.0.1")
        authenticator.configure_curve(domain="*", location=self.keyring)
        self.load_or_create_key()

    @property
    def key_basename(self):
        """Certificate base name derived from the node identity."""
        scrubbed = scrub_identity_name_for_certfile(self.identity)
        return scrubbed

    @property
    def key_filename(self):
        """Path of this node's public certificate inside the keyring."""
        return os.path.join(self.keyring, f"{self.key_basename}.key")

    @property
    def secret_key_filename(self):
        """Path of this node's secret certificate (``<key_filename>_secret``)."""
        return f"{self.key_filename}_secret"

    def load_key(self):
        """Load the node's key pair from its secret certificate file.

        Sets ``self.pubkey`` and ``self.privkey``. Errors raised by
        ``zmq.auth.load_certificate`` propagate to the caller.
        """
        self.log.debug("loading node key-pair")
        public, secret = zmq.auth.load_certificate(self.secret_key_filename)
        self.pubkey, self.privkey = public, secret

    def load_or_create_key(self):
        """Load the node's key pair, generating it first if loading fails.

        When loading raises an I/O error, the keyring directory is created
        (mode 0700) if needed, a new certificate pair is written there and
        the load is retried once.
        """
        try:
            self.load_key()
        except OSError as load_error:  # IOError is an alias of OSError
            self.log.debug("error loading key: %s", load_error)
            self.log.debug("creating node key-pair")
            os.makedirs(self.keyring, mode=0o700, exist_ok=True)
            zmq.auth.create_certificates(self.keyring, self.key_basename)
            self.load_key()

    def preprocess_message(self, msg, msg_class=TaggedMessage):
        """Coerce *msg* to *msg_class* and return it with its repr and encoding.

        A value that is not already a *msg_class* instance is wrapped: a
        list or tuple is splatted into the constructor, anything else is
        passed as a single argument, with ``name=self.identity``.

        Returns:
            ``(message, repr(message), message.encode())``.
        """
        if isinstance(msg, msg_class):
            wrapped = msg
        else:
            parts = msg if isinstance(msg, (list, tuple)) else (msg, )
            wrapped = msg_class(*parts, name=self.identity)
        return wrapped, repr(wrapped), wrapped.encode()

    def route_failed(self, msg):
        """Record a delivery failure for *msg* and requeue or drop it.

        The message's failure counter is incremented. Up to five failures
        the message is appended to ``self.route_queue``; after that it is
        discarded.

        Raises:
            TypeError: if *msg* is not a ``RoutedMessage``.
        """
        if not isinstance(msg, RoutedMessage):
            raise TypeError("msg must already be a RoutedMessage")
        msg.failures += 1
        if msg.failures > 5:
            self.log.error("discarding %s after %d failures", repr(msg),
                           msg.failures)
            return
        self.log.debug("(re)queueing %s for later delivery", repr(msg))
        pending = self.route_queue
        if len(pending) == pending.maxlen:
            # The bounded deque is about to evict its oldest entry.
            self.log.error("route_queue full, discarding %s",
                           repr(pending[0]))
        pending.append(msg)

    def route_message(self, to, msg):
        """Send *msg* to node *to* through the ROUTER socket.

        Args:
            to: a ``StupidNode`` (its identity is used), a list/tuple (its
                last element is used) or a destination name. When
                ``self.routes`` has an entry for the destination, the first
                element of that entry is prepended as the next hop.
            msg: a ``RoutedMessage`` (its ``to`` is overwritten) or a payload
                that gets wrapped into one.

        A send failure whose text contains "Host unreachable" is handed to
        ``route_failed``; any other ``ZMQError`` is re-raised.
        """
        dest = to.identity if isinstance(to, StupidNode) else to
        if isinstance(dest, (list, tuple)):
            dest = dest[-1]
        known_route = self.routes.get(dest)
        if known_route:
            dest = (known_route[0], dest)

        if isinstance(msg, RoutedMessage):
            msg.to = dest
        else:
            # preprocess_message splats the tuple into RoutedMessage(), so
            # the destination has to come first.
            if isinstance(msg, list):
                payload = tuple(msg)
            elif isinstance(msg, tuple):
                payload = msg
            else:
                payload = (msg, )
            msg = (dest, ) + payload

        routed, shown, wire = self.preprocess_message(
            msg, msg_class=RoutedMessage)
        self.log.debug("routing message %s -- encoding: %s", shown, wire)
        try:
            self.router.send_multipart(wire)
        except zmq.error.ZMQError as send_error:
            self.log.debug("route to %s failed: %s", dest, send_error)
            if "Host unreachable" in str(send_error):
                self.route_failed(routed)
            else:
                raise

    def deal_message(self, msg):
        """Send *msg* over the dealer sockets only, skipping the PUB socket.

        Implemented as ``publish_message`` with ``no_publish=True``, so the
        local workflow still runs on the message.
        """
        note = "dealing message (actually publishing with no_publish=True)"
        self.log.debug(note)
        self.publish_message(msg, no_publish=True)

    def publish_message(self,
                        msg,
                        no_deal=False,
                        no_deal_to=None,
                        no_publish=False):
        """Run *msg* through the local workflow, publish it and deal it to peers.

        Args:
            msg: a ``TaggedMessage`` or anything ``preprocess_message`` can
                wrap into one.
            no_deal: when true, skip the dealer sockets entirely.
            no_deal_to: selects the dealer sockets to skip. ``None`` skips
                none; a callable is used as ``ok_send(index) -> bool``; a
                ``zmq.Socket`` skips that dealer; an ``int`` skips the dealer
                at that index; a list, tuple, set or frozenset skips every
                index it contains.
            no_publish: when true, do not send on the PUB socket.

        Raises:
            ValueError: if ``no_deal_to`` is a socket not in ``self.dealer``.
            TypeError: if ``no_deal_to`` has an unsupported type. (Previously
                this surfaced as an ``UnboundLocalError`` in the send loop.)
        """
        tmsg, rmsg, emsg = self.preprocess_message(msg)
        self.log.debug(
            "publishing message %s no_publish=%s, no_deal=%s, no_deal_to=%s",
            rmsg,
            no_publish,
            no_deal,
            no_deal_to,
        )
        self.local_workflow(tmsg)
        if not no_publish:
            self.pub.send_multipart(emsg)
        if no_deal:
            return

        # Build the predicate deciding, per dealer index, whether to send.
        if no_deal_to is None:

            def ok_send(_index):
                return True
        elif callable(no_deal_to):
            ok_send = no_deal_to
        elif isinstance(no_deal_to, zmq.Socket):
            skipped_index = self.dealer.index(no_deal_to)

            def ok_send(index):
                return index != skipped_index
        elif isinstance(no_deal_to, int):

            def ok_send(index):
                return index != no_deal_to
        elif isinstance(no_deal_to, (list, tuple, set, frozenset)):

            def ok_send(index):
                return index not in no_deal_to
        else:
            raise TypeError(
                f"unsupported no_deal_to value: {no_deal_to!r}")

        for i, sock in enumerate(self.dealer):
            if ok_send(i):
                self.log.debug("dealing message %s to %s", rmsg,
                               self.endpoints[i])
                sock.send_multipart(emsg)
            else:
                self.log.debug("not sending %s to %s", rmsg, self.endpoints[i])

    def mk_socket(self, stype, enable_curve=True):
        """Create and configure a socket of type *stype*.

        Args:
            stype: zmq socket type constant (``zmq.PUB``, ``zmq.REQ``, ...).
            enable_curve: when true the socket is created in the crypto
                context and gets the node's CURVE key pair; otherwise it is
                created in the cleartext context without keys.

        Returns:
            The new, not yet bound or connected, socket.
        """
        # Options are set through pyzmq's attribute access (case-insensitive
        # option names); the equivalent setsockopt() form would be e.g.
        # socket.setsockopt(zmq.LINGER, 1).
        context = self.ctx if enable_curve else self.cleartext_ctx
        socket = context.socket(stype)
        type_name = zmq_socket_type_name(stype)
        if enable_curve:
            self.log.debug("create %s socket in crypto context", type_name)
        else:
            self.log.debug("create %s socket in cleartext context", type_name)

        socket.linger = 1
        socket.identity = self.identity.encode()
        socket.reconnect_ivl = 1000
        socket.reconnect_ivl_max = 10000

        if not enable_curve:
            return socket

        socket.curve_secretkey = self.privkey
        socket.curve_publickey = self.pubkey
        return socket

    def local_workflow(self, msg):
        """Apply ``local_react`` then ``all_react`` to a locally produced message.

        ``all_react`` only runs when ``local_react`` returned a truthy value;
        the last value produced is returned.
        """
        self.log.debug("start local_workflow %s", repr(msg))
        reacted = self.local_react(msg)
        return self.all_react(reacted) if reacted else reacted

    def sub_workflow(self, socket):
        """Receive one message from a SUB socket and run the react chain on it.

        The chain is ``sub_react`` -> ``nonlocal_react`` -> ``all_react`` and
        stops as soon as a step returns a falsy value.

        Returns:
            The value left by the last step that ran.
        """
        position = self.sub.index(socket)
        remote = self.endpoints[position]
        msg = self.sub_receive(socket, position)
        self.log.debug("start sub_workflow (idx=%d -> endpoint=%s) %s",
                       position, remote, repr(msg))
        for step in (self.sub_react, self.nonlocal_react, self.all_react):
            if not msg:
                break
            msg = step(msg, idx=position)
        self.log.debug("end sub_workflow")
        return msg

    def router_workflow(self):
        """Receive one message from the ROUTER socket and run the react chain.

        The chain is ``router_react`` -> ``nonlocal_react`` -> ``all_react``
        and stops as soon as a step returns a falsy value.

        Returns:
            The value left by the last step that ran.
        """
        current = self.router_receive()
        self.log.debug("start router_workflow %s", repr(current))
        steps = (self.router_react, self.nonlocal_react, self.all_react)
        for step in steps:
            if not current:
                break
            current = step(current)
        self.log.debug("end router_workflow")
        return current

    def dealer_workflow(self, socket):
        """Receive one message from a DEALER socket and run the react chain.

        The chain is ``dealer_react`` -> ``nonlocal_react`` -> ``all_react``
        and stops as soon as a step returns a falsy value.

        Returns:
            The value left by the last step that ran.
        """
        position = self.dealer.index(socket)
        remote = self.endpoints[position]
        current = self.dealer_receive(socket, position)
        self.log.debug("start deal_workflow (idx=%d -> endpoint=%s) %s",
                       position, remote, repr(current))
        steps = (self.dealer_react, self.nonlocal_react, self.all_react)
        for step in steps:
            if not current:
                break
            current = step(current, idx=position)
        self.log.debug("end deal_workflow")
        return current

    def sub_receive(self, socket, idx):  # pylint: disable=unused-argument
        """Read one multipart message from *socket* as a ``TaggedMessage``."""
        frames = socket.recv_multipart()
        return TaggedMessage(*frames)

    def dealer_receive(self, socket, idx):  # pylint: disable=unused-argument
        """Read one multipart message from a DEALER socket.

        Returns:
            The decoded ``RoutedMessage`` when the frames decode as one;
            otherwise a ``TaggedMessage`` with ``publish_mark`` set to False.
        """
        frames = socket.recv_multipart()
        routed = RoutedMessage.decode(frames)
        if not routed:
            # Not a routed message: it was simply meant for us, so wrap it
            # and mark it as not to be published.
            tagged = TaggedMessage(*frames)
            tagged.publish_mark = False
            return tagged
        return routed

    def router_receive(self):
        """Read one multipart message from the ROUTER socket.

        The first frame (the sender id added by the ROUTER socket) is
        dropped; the name carried in the message tag is trusted instead.

        Returns:
            A ``RoutedMessage`` when the remaining frames decode as one,
            otherwise a ``TaggedMessage`` built from them.
        """
        _sender, *frames = self.router.recv_multipart()
        return RoutedMessage.decode(frames) or TaggedMessage(*frames)

    def all_react(self, msg, idx=None):  # pylint: disable=unused-argument
        """Final hook of every workflow; this base version returns *msg* unchanged.

        ``idx`` is the sub/dealer index when called from those workflows and
        is left at ``None`` by the local and router workflows.
        """
        return msg

    def sub_react(self, msg, idx=None):  # pylint: disable=unused-argument
        """First hook of ``sub_workflow``; this base version returns *msg* unchanged.

        A falsy return value stops the rest of the workflow chain.
        """
        return msg

    def dealer_react(self, msg, idx=None):  # pylint: disable=unused-argument
        """First hook of ``dealer_workflow``; this base version returns *msg* unchanged.

        A falsy return value stops the rest of the workflow chain.
        """
        return msg

    def router_react(self, msg):
        """First hook of ``router_workflow``; this base version returns *msg* unchanged.

        Unlike the sub/dealer hooks it takes no ``idx``: the router workflow
        calls its hooks with the message only. A falsy return value stops
        the rest of the workflow chain.
        """
        return msg

    def nonlocal_react(self, msg, idx=None):
        """Hook for messages received from other nodes.

        A ``RoutedMessage`` is delegated to ``routed_react`` and that result
        is returned; any other message is returned unchanged.
        """
        if not isinstance(msg, RoutedMessage):
            return msg
        return self.routed_react(msg, idx=idx)

    def local_react(self, msg):
        """First hook of ``local_workflow``; this base version returns *msg* unchanged.

        A falsy return value makes ``local_workflow`` skip ``all_react``.
        """
        return msg

    def routed_react(self, msg, idx=None):  # pylint: disable=unused-argument
        """Hook called by ``nonlocal_react`` for every ``RoutedMessage``.

        This base version returns ``False``, which ends the workflow chain
        for the message: ``all_react`` is not called for it.
        """
        return False

    def poll(self, timeo=500, other_cb=None):
        """Check for incoming messages and dispatch each ready socket.

        Every item reported with exactly ``zmq.POLLIN`` is handed to the
        matching workflow (sub, dealer or router). Any other item goes to
        *other_cb* when it is callable.

        Args:
            timeo: poll timeout, passed straight to ``self.poller.poll``.
            other_cb: optional callable invoked with any ready item that is
                not one of this node's sub, dealer or router sockets.

        Returns:
            The list of ``TaggedMessage`` results produced by the workflows
            or by *other_cb*; other results are dropped.

        Raises:
            Exception: if a ready item has no workflow and no *other_cb*
                was given.
        """
        ready = dict(self.poller.poll(timeo))
        ret = []
        for item, events in ready.items():
            if events != zmq.POLLIN:
                continue
            if item in self.sub:
                res = self.sub_workflow(item)
            elif item in self.dealer:
                res = self.dealer_workflow(item)
            elif item is self.router:
                res = self.router_workflow()
            elif callable(other_cb):
                res = other_cb(item)
            else:
                # The former "log and drain the socket" alternative was
                # disabled with `if False`, so an unknown item is always fatal.
                self.log.error(
                    "no workflow defined for socket of type %s -- regarding as fatal",
                    zmq_socket_type_name(item),
                )
                # note: this normally doesn't trigger an exit... thanks threading
                raise Exception("unhandled poll item")
            if isinstance(res, TaggedMessage):
                ret.append(res)
        return ret

    def interrupt(self, signo, eframe):  # pylint: disable=unused-argument
        """SIGINT handler: release the node's resources and exit with status 0."""
        print(" kaboom")
        self.closekill()
        # Same effect as sys.exit(0), which raises SystemExit(0).
        raise SystemExit(0)

    def closekill(self):
        """Stop the helper threads and destroy both zmq contexts.

        Each resource is removed from the instance once released, so the
        method can be called again (it is also called from ``__del__``)
        without touching anything twice.
        """
        authenticator = getattr(self, "auth", None)
        if authenticator is not None:
            if authenticator.is_alive():
                self.log.debug("trying to stop auth thread")
                authenticator.stop()
                self.log.debug("auth thread seems to have stopped")
            del self.auth

        if hasattr(self, "_who_are_you_thread"):
            wai_thread = self._who_are_you_thread
            if wai_thread.is_alive():
                self.log.debug("WAI Thread seems to be alive, trying to join")
                # The reply loop re-checks this flag after each poll timeout.
                self._who_are_you_continue = False
                wai_thread.join()
                self.log.debug("WAI Thread seems to jave joined us.")
            del self._who_are_you_thread

        # Cleartext context first, then the crypto one; destroy(1) also
        # closes the sockets still open in the context.
        contexts = (
            ("cleartext_ctx", "destroying cleartext context"),
            ("ctx", "destroying crypto context"),
        )
        for attr_name, message in contexts:
            if hasattr(self, attr_name):
                self.log.debug(message)
                getattr(self, attr_name).destroy(1)
                delattr(self, attr_name)

    def __del__(self):
        """Release the node's resources when the object is finalized.

        The finalizer also runs for instances whose ``__init__`` raised
        before ``self.log`` was assigned, so the logger is looked up
        defensively. ``closekill`` itself only logs after finding a resource,
        and all of those are created after the logger.
        """
        log = getattr(self, "log", None)
        if log is not None:
            log.debug("%s is being deleted", self)
        self.closekill()

    def bind(self, socket, enable_curve=True):
        """Bind *socket* to this node's endpoint address for its socket type.

        Args:
            socket: the zmq socket to bind.
            enable_curve: when true, mark the socket as a CURVE server first.

        Raises:
            zmq.ZMQError: if binding fails. The error keeps the original
                ``errno`` and its message names the address; the original
                error is chained as ``__cause__``.
        """
        if enable_curve:
            socket.curve_server = True  # must come before bind
        # Computed outside the try block so the handler can always name it.
        address = self.endpoint.format(socket.type)
        try:
            socket.bind(address)
        except zmq.ZMQError as e:
            # ZMQError takes (errno, msg): pass both so callers can still
            # test the error code.
            raise zmq.ZMQError(e.errno,
                               f"unable to bind {address}: {e}") from e

    def who_are_you_request(self, endpoint):
        """Ask the node at *endpoint*, in cleartext, for its identity and key.

        Args:
            endpoint: object whose ``format(zmq.REQ)`` gives the address of
                the remote REP socket.

        Returns:
            The two-frame reply (``[identity, public_key]`` as sent by
            ``who_are_you_reply_machine``), or ``(None, None)`` when the
            reply does not have exactly two frames.

        NOTE: ``recv_multipart`` blocks until the remote node answers.
        """
        req = self.mk_socket(zmq.REQ, enable_curve=False)
        try:
            req.connect(endpoint.format(zmq.REQ))
            msg = b"Who are you?"
            self.log.debug("sending cleartext request: %s", msg)
            req.send(msg)
            self.log.debug("waiting for reply")
            res = req.recv_multipart()
            self.log.debug("received reply: %s", res)
            if len(res) == 2:
                return res
            return None, None
        finally:
            # Close the one-shot socket on every path, including the
            # successful one and errors.
            req.close()

    def pubkey_pathname(self, node_id):
        """Return the keyring path of the public certificate for *node_id*.

        Args:
            node_id: a node identity, or an ``Endpoint`` whose ``host`` is
                then used as the identity.

        Returns:
            ``<keyring>/<scrubbed node_id>.key``.
        """
        if isinstance(node_id, Endpoint):
            # Use the host of the endpoint that was passed in, not the
            # attribute of the Endpoint class itself.
            node_id = node_id.host
        fname = scrub_identity_name_for_certfile(node_id) + ".key"
        pname = os.path.join(self.keyring, fname)
        return pname

    def learn_or_load_endpoint_pubkey(self, endpoint):
        """Return the public key of the node at *endpoint*.

        When no certificate file exists for the endpoint, the node is asked
        for its identity and key through ``who_are_you_request``. On an
        answer, ``endpoint.identity`` is set and a minimal certificate file
        is written under the node's identity unless one already exists. The
        certificate is then loaded from disk.

        Returns:
            The public key read by ``zmq.auth.load_certificate``.
        """
        cert_path = self.pubkey_pathname(endpoint)
        if not os.path.isfile(cert_path):
            self.log.debug(
                "%s does not exist yet, trying to learn certificate",
                cert_path)
            node_id, public_key = self.who_are_you_request(endpoint)
            if node_id:
                endpoint.identity = node_id.decode()
                # From here on the certificate is looked up by identity.
                cert_path = self.pubkey_pathname(node_id)
                if not os.path.isfile(cert_path):
                    # Minimal certificate: an empty "metadata" section (pyzmq
                    # writes "    key = value" lines there) followed by the
                    # "curve" section holding the quoted public key.
                    pieces = (
                        b"# generated via rep/req pubkey transfer\n\n",
                        b"metadata\n",
                        b"curve\n",
                        b'    public-key = "',
                        public_key,
                        b'"',
                    )
                    with open(cert_path, "wb") as fh:
                        fh.writelines(pieces)
        self.log.debug("loading certificate %s", cert_path)
        loaded_public, _ = zmq.auth.load_certificate(cert_path)
        return loaded_public

    def connect_to_endpoints(self, *endpoints):
        """Connect to every given endpoint, in order, and return ``self``."""
        self.log.debug("connecting remote endpoints")
        for target in endpoints:
            self.connect_to_endpoint(target)
        self.log.debug("remote endpoints connected")
        return self

    def _create_connected_socket(self,
                                 endpoint,
                                 stype,
                                 pubkey,
                                 preconnect=None):
        """Create a CURVE client socket of type *stype* connected to *endpoint*.

        Args:
            endpoint: object whose ``format(stype)`` gives the address.
            stype: zmq socket type constant.
            pubkey: the server's public key, set as ``curve_serverkey``.
            preconnect: optional callable invoked with the socket just
                before ``connect``.

        Returns:
            The connected socket.
        """
        type_name = zmq_socket_type_name(stype)
        self.log.debug("creating %s socket to endpoint=%s", type_name,
                       endpoint)
        sock = self.mk_socket(stype)
        sock.curve_serverkey = pubkey
        if callable(preconnect):
            preconnect(sock)
        address = endpoint.format(stype)
        sock.connect(address)
        return sock

    def connect_to_endpoint(self, endpoint):
        """Connect a SUB and a DEALER socket to *endpoint* and poll them.

        Args:
            endpoint: a ``StupidNode`` (its endpoint is used), an
                ``Endpoint``, or a value passed to ``Endpoint(...)``.

        The new sockets and the endpoint are appended to ``self.sub``,
        ``self.dealer`` and ``self.endpoints`` so they share one index.

        Returns:
            ``self``.
        """
        if isinstance(endpoint, StupidNode):
            endpoint = endpoint.endpoint
        elif not isinstance(endpoint, Endpoint):
            endpoint = Endpoint(endpoint)

        self.log.debug("learning or loading endpoint=%s pubkey", endpoint)
        server_key = self.learn_or_load_endpoint_pubkey(endpoint)

        def subscribe(sock):
            # Set the subscription before the socket connects.
            return sock.setsockopt_string(zmq.SUBSCRIBE, self.channel)

        sub_sock = self._create_connected_socket(endpoint, zmq.SUB,
                                                 server_key, subscribe)
        self.poller.register(sub_sock, zmq.POLLIN)
        self.sub.append(sub_sock)

        dealer_sock = self._create_connected_socket(endpoint, zmq.DEALER,
                                                    server_key)
        self.poller.register(dealer_sock, zmq.POLLIN)
        self.dealer.append(dealer_sock)

        self.endpoints.append(endpoint)

        return self

    def __repr__(self):
        """Return ``ClassName(identity)``."""
        return "{}({})".format(self.__class__.__name__, self.identity)