Example #1
class ZMQPull(ZMQ):
    classname = "ZMQPull"

    def __init__(self, name, options, inbound):
        super().__init__(name, options, inbound)
        self.socket_type = zmq.PULL

    def secure_setup(self):
        # Load certificates
        # TODO: handle errors
        self.auth = ThreadAuthenticator(self.context)
        self.auth.start()
        self.LOG.debug("Server keys in %s", self.secure_config["self"])
        sock_pub, sock_priv = zmq.auth.load_certificate(self.secure_config["self"])
        if self.secure_config.get("clients", None) is not None:
            self.LOG.debug("Client certificates in %s",
                           self.secure_config["clients"])
            self.auth.configure_curve(domain="*",
                                      location=self.secure_config["clients"])
        else:
            self.LOG.debug("Every clients can connect")
            self.auth.configure_curve(domain="*",
                                      location=zmq.auth.CURVE_ALLOW_ANY)

        # Setup the socket
        self.sock.curve_publickey = sock_pub
        self.sock.curve_secretkey = sock_priv
        self.sock.curve_server = True
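
For orientation, a client connecting to this CURVE-enabled PULL server needs its own keypair plus the server's public key. A minimal client-side sketch (certificate paths and the endpoint are placeholders, not taken from the example above):

import zmq
import zmq.auth

# Placeholder certificate paths; files like these are typically generated
# with zmq.auth.create_certificates() and exchanged out of band.
client_public, client_secret = zmq.auth.load_certificate("certs/client.key_secret")
server_public, _ = zmq.auth.load_certificate("certs/server.key")

ctx = zmq.Context()
push = ctx.socket(zmq.PUSH)
push.curve_publickey = client_public   # client keypair
push.curve_secretkey = client_secret
push.curve_serverkey = server_public   # only the server's public key is needed
push.connect("tcp://127.0.0.1:5555")   # placeholder endpoint
push.send(b"hello")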
Example #2
    def _init_txzmq(self):
        """
        Configure the txzmq components and connection.
        """
        self._zmq_factory = txzmq.ZmqFactory()
        self._zmq_factory.registerForShutdown()
        self._zmq_connection = txzmq.ZmqREPConnection(self._zmq_factory)

        context = self._zmq_factory.context
        socket = self._zmq_connection.socket

        def _gotMessage(messageId, messageParts):
            self._zmq_connection.reply(messageId, "OK")
            self._process_request(messageParts)

        self._zmq_connection.gotMessage = _gotMessage

        if flags.ZMQ_HAS_CURVE:
            # Start an authenticator for this context.
            auth = ThreadAuthenticator(context)
            auth.start()
            # XXX do not hardcode this here.
            auth.allow('127.0.0.1')

            # Tell authenticator to use the certificate in a directory
            auth.configure_curve(domain='*', location=zmq.auth.CURVE_ALLOW_ANY)
            public, secret = get_backend_certificates()
            socket.curve_publickey = public
            socket.curve_secretkey = secret
            socket.curve_server = True  # must come before bind

        proto, addr = self._server_address.split('://')  # tcp/ipc, ip/socket
        socket.bind(self._server_address)
        if proto == 'ipc':
            os.chmod(addr, 0o600)
Example #3
def main():
    localhost = socket_m.getfqdn()

    port = "5556"
    # ip = "*"
    ip = socket_m.gethostbyaddr(localhost)[2][0]

    context = zmq.Context()
    socket = context.socket(zmq.PULL)
    socket.zap_domain = b'global'
    socket.bind("tcp://" + ip + ":%s" % port)

    auth = ThreadAuthenticator(context)

    host = localhost
    # host = asap3-p00
    whitelist = socket_m.gethostbyaddr(host)[2][0]
    # whitelist = None
    auth.start()

    if whitelist is None:
        auth.auth = None
    else:
        auth.allow(whitelist)

    try:
        while True:
            message = socket.recv_multipart()
            print("received reply ", message)
    except KeyboardInterrupt:
        pass
    finally:
        auth.stop()
Example #4
File: inputs.py Project: ivoire/ReactOBus
class ZMQPull(ZMQ):
    classname = "ZMQPull"

    def __init__(self, name, options, inbound):
        super().__init__(name, options, inbound)
        self.socket_type = zmq.PULL

    def secure_setup(self):
        # Load certificates
        # TODO: handle errors
        self.auth = ThreadAuthenticator(self.context)
        self.auth.start()
        self.LOG.debug("Server keys in %s", self.secure_config["self"])
        sock_pub, sock_priv = zmq.auth.load_certificate(self.secure_config["self"])
        if self.secure_config.get("clients", None) is not None:
            self.LOG.debug("Client certificates in %s", self.secure_config["clients"])
            self.auth.configure_curve(domain="*", location=self.secure_config["clients"])
        else:
            self.LOG.debug("Every clients can connect")
            self.auth.configure_curve(domain="*", location=zmq.auth.CURVE_ALLOW_ANY)

        # Setup the socket
        self.sock.curve_publickey = sock_pub
        self.sock.curve_secretkey = sock_priv
        self.sock.curve_server = True
Example #5
File: security.py Project: OTL/jps
class Authenticator(object):
    _authenticators = {}

    @classmethod
    def instance(cls, public_keys_dir):
        '''Please avoid creating multiple instances'''
        if public_keys_dir in cls._authenticators:
            return cls._authenticators[public_keys_dir]
        new_instance = cls(public_keys_dir)
        cls._authenticators[public_keys_dir] = new_instance
        return new_instance

    def __init__(self, public_keys_dir):
        self._auth = ThreadAuthenticator(zmq.Context.instance())
        self._auth.start()
        self._auth.allow('*')
        self._auth.configure_curve(domain='*', location=public_keys_dir)

    def set_server_key(self, zmq_socket, server_secret_key_path):
        '''must call before bind'''
        load_and_set_key(zmq_socket, server_secret_key_path)
        zmq_socket.curve_server = True

    def set_client_key(self, zmq_socket, client_secret_key_path, server_public_key_path):
        '''must call before bind'''
        load_and_set_key(zmq_socket, client_secret_key_path)
        server_public, _ = zmq.auth.load_certificate(server_public_key_path)
        zmq_socket.curve_serverkey = server_public

    def stop(self):
        self._auth.stop()
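
A hypothetical server-side usage of the Authenticator class above (directory and key paths are placeholders; load_and_set_key is the jps helper referenced in set_server_key):

import zmq

# Sketch only: the directory and file names below are placeholders.
auth = Authenticator.instance("public_keys")

ctx = zmq.Context.instance()
rep = ctx.socket(zmq.REP)
auth.set_server_key(rep, "private_keys/server.key_secret")  # must be called before bind
rep.bind("tcp://*:5555")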
Example #6
    def _run(self):
        """
        Start a loop to process the ZMQ requests from the signaler client.
        """
        logger.debug("Running SignalerQt loop")
        context = zmq.Context()
        socket = context.socket(zmq.REP)

        # Start an authenticator for this context.
        auth = ThreadAuthenticator(context)
        auth.start()
        auth.allow('127.0.0.1')

        # Tell authenticator to use the certificate in a directory
        auth.configure_curve(domain='*', location=zmq.auth.CURVE_ALLOW_ANY)
        public, secret = get_frontend_certificates()
        socket.curve_publickey = public
        socket.curve_secretkey = secret
        socket.curve_server = True  # must come before bind

        socket.bind(self.BIND_ADDR)

        while self._do_work.is_set():
            # Wait for next request from client
            try:
                request = socket.recv(zmq.NOBLOCK)
                logger.debug("Received request: '{0}'".format(request))
                socket.send("OK")
                self._process_request(request)
            except zmq.ZMQError as e:
                if e.errno != zmq.EAGAIN:
                    raise
            time.sleep(0.01)

        logger.debug("SignalerQt thread stopped.")
Example #7
    def _run(self):
        """
        Start a loop to process the ZMQ requests from the signaler client.
        """
        logger.debug("Running SignalerQt loop")
        context = zmq.Context()
        socket = context.socket(zmq.REP)

        # Start an authenticator for this context.
        auth = ThreadAuthenticator(context)
        auth.start()
        auth.allow('127.0.0.1')

        # Tell authenticator to use the certificate in a directory
        auth.configure_curve(domain='*', location=zmq.auth.CURVE_ALLOW_ANY)
        public, secret = get_frontend_certificates()
        socket.curve_publickey = public
        socket.curve_secretkey = secret
        socket.curve_server = True  # must come before bind

        socket.bind(self.BIND_ADDR)

        while self._do_work.is_set():
            # Wait for next request from client
            try:
                request = socket.recv(zmq.NOBLOCK)
                # logger.debug("Received request: '{0}'".format(request))
                socket.send("OK")
                self._process_request(request)
            except zmq.ZMQError as e:
                if e.errno != zmq.EAGAIN:
                    raise
            time.sleep(0.01)

        logger.debug("SignalerQt thread stopped.")
Example #8
    def _init_zmq(self):
        """
        Configure the zmq components and connection.
        """
        context = zmq.Context()
        socket = context.socket(zmq.REP)

        if flags.ZMQ_HAS_CURVE:
            # Start an authenticator for this context.
            auth = ThreadAuthenticator(context)
            auth.start()
            # XXX do not hardcode this here.
            auth.allow('127.0.0.1')

            # Tell authenticator to use the certificate in a directory
            auth.configure_curve(domain='*', location=zmq.auth.CURVE_ALLOW_ANY)
            public, secret = get_backend_certificates()
            socket.curve_publickey = public
            socket.curve_secretkey = secret
            socket.curve_server = True  # must come before bind

        socket.bind(self.BIND_ADDR)
        if not flags.ZMQ_HAS_CURVE:
            os.chmod(self.SOCKET_FILE, 0o600)

        self._zmq_socket = socket
Example #9
def run_mdp_broker():
    args = docopt("""Usage:
        mdp-broker [options] <config>

    Options:
        -h --help                 show this help message and exit
        -s --secure               generate (and print) client & broker keys for a secure server
    """)
    global log
    _setup_logging(args['<config>'])

    log = logging.getLogger(__name__)

    cp = ConfigParser()
    cp.read(args['<config>'])

    # Parse settings a bit
    raw = dict(
        (option, cp.get('mdp-broker', option))
        for option in cp.options('mdp-broker'))
    s = SettingsSchema().to_python(raw)

    if args['--secure']:
        broker_key = Key.generate()
        client_key = Key.generate()
        s['key'] = dict(
            broker=broker_key,
            client=client_key)
        log.info('Auto-generated keys: %s_%s_%s',
            broker_key.public, client_key.public, client_key.secret)
        log.info(' broker.public: %s', broker_key.public)
        log.info(' client.public: %s', client_key.public)
        log.info(' client.secret: %s', client_key.secret)

    if s['key']:
        log.info('Starting secure mdp-broker on %s', s['uri'])
        auth = ThreadAuthenticator()
        auth.start()
        auth.thread.authenticator.certs['*'] = {
            s['key']['client'].public: 'OK'}

        broker = SecureMajorDomoBroker(s['key']['broker'], s['uri'])
    else:
        log.info('Starting mdp-broker on %s', s['uri'])
        broker = MajorDomoBroker(s['uri'])
    try:
        broker.serve_forever()
    except:
        auth.stop()
        raise
Example #10
def main():
    auth = ThreadAuthenticator(zmq.Context.instance())
    auth.start()
    auth.allow('127.0.0.1')
    # Tell the authenticator how to handle CURVE requests
    auth.configure_curve(domain='*', location=zmq.auth.CURVE_ALLOW_ANY)

    key = Key.load('example/broker.key_secret')
    broker = SecureMajorDomoBroker(key, sys.argv[1])
    try:
        broker.serve_forever()
    except KeyboardInterrupt:
        auth.stop()
        raise
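
The broker examples above load a pre-generated certificate such as 'example/broker.key_secret'. One way to produce such key files (a sketch; the directory and name are arbitrary) is pyzmq's certificate helper:

import os
import zmq.auth

# Writes broker.key (public) and broker.key_secret (secret) into ./example;
# the path only mirrors the one used above and is otherwise arbitrary.
keys_dir = "example"
if not os.path.exists(keys_dir):
    os.makedirs(keys_dir)
public_file, secret_file = zmq.auth.create_certificates(keys_dir, "broker")
print(public_file, secret_file)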
Example #11
def main():
    auth = ThreadAuthenticator(zmq.Context.instance())
    auth.start()
    auth.allow('127.0.0.1')
    # Tell the authenticator how to handle CURVE requests
    auth.configure_curve(domain='*', location=zmq.auth.CURVE_ALLOW_ANY)

    key = Key.load('example/broker.key_secret')
    broker = SecureMajorDomoBroker(key, sys.argv[1])
    try:
        broker.serve_forever()
    except KeyboardInterrupt:
        auth.stop()
        raise
Example #12
    def _start_thread_auth(self, socket):
        """
        Start the zmq curve thread authenticator.

        :param socket: The socket in which to configure the authenticator.
        :type socket: zmq.Socket
        """
        authenticator = ThreadAuthenticator(self._factory.context)
        authenticator.start()
        # XXX do not hardcode this here.
        authenticator.allow('127.0.0.1')
        # tell authenticator to use the certificate in a directory
        public_keys_dir = os.path.join(self._config_prefix, PUBLIC_KEYS_PREFIX)
        authenticator.configure_curve(domain="*", location=public_keys_dir)
        socket.curve_server = True  # must come before bind
Example #13
    def _start_thread_auth(self, socket):
        """
        Start the zmq curve thread authenticator.

        :param socket: The socket in which to configure the authenticator.
        :type socket: zmq.Socket
        """
        authenticator = ThreadAuthenticator(self._factory.context)
        authenticator.start()
        # XXX do not hardcode this here.
        authenticator.allow('127.0.0.1')
        # tell authenticator to use the certificate in a directory
        public_keys_dir = os.path.join(self._config_prefix, PUBLIC_KEYS_PREFIX)
        authenticator.configure_curve(domain="*", location=public_keys_dir)
        socket.curve_server = True  # must come before bind
Example #14
def run_mdp_broker():
    args = docopt("""Usage:
        mdp-broker [options] <config>

    Options:
        -h --help                 show this help message and exit
        -s --secure               generate (and print) client & broker keys for a secure server
    """)
    global log
    _setup_logging(args['<config>'])

    log = logging.getLogger(__name__)

    cp = ConfigParser()
    cp.read(args['<config>'])

    # Parse settings a bit
    raw = dict((option, cp.get('mdp-broker', option))
               for option in cp.options('mdp-broker'))
    s = SettingsSchema().to_python(raw)

    if args['--secure']:
        broker_key = Key.generate()
        client_key = Key.generate()
        s['key'] = dict(broker=broker_key, client=client_key)
        log.info('Auto-generated keys: %s_%s_%s', broker_key.public,
                 client_key.public, client_key.secret)
        log.info(' broker.public: %s', broker_key.public)
        log.info(' client.public: %s', client_key.public)
        log.info(' client.secret: %s', client_key.secret)

    if s['key']:
        log.info('Starting secure mdp-broker on %s', s['uri'])
        auth = ThreadAuthenticator()
        auth.start()
        auth.thread.authenticator.certs['*'] = {
            s['key']['client'].public: 'OK'
        }

        broker = SecureMajorDomoBroker(s['key']['broker'], s['uri'])
    else:
        log.info('Starting mdp-broker on %s', s['uri'])
        broker = MajorDomoBroker(s['uri'])
    try:
        broker.serve_forever()
    except:
        auth.stop()
        raise
Example #15
def main():
    port = "5556"
    socket_ip = "*"
    # ip = socket.getfqdn()

    context = zmq.Context()
    auth = ThreadAuthenticator(context)
    auth.start()

    whitelist = [socket.getfqdn()]
    for host in whitelist:
        hostname, tmp, ip = socket.gethostbyaddr(host)
        auth.allow(ip[0])

    zmq_socket = context.socket(zmq.PUSH)
    zmq_socket.zap_domain = b'global'
    zmq_socket.bind("tcp://" + socket_ip + ":%s" % port)

    try:
        for i in range(5):
            message = ["World"]
            print("Send: ", message)
            res = zmq_socket.send_multipart(message, copy=False, track=True)
            if res.done:
                print("res: done")
            else:
                print("res: waiting")
                res.wait()
                print("res: waiting...")
            print("sleeping...")
            if i == 1:
                auth.stop()
                zmq_socket.close(0)

                auth.start()
                #                ip = socket.gethostbyaddr(socket.getfqdn())[2]
                #                auth.allow(ip[0])
                ip = socket.gethostbyaddr(socket.getfqdn())[2]
                auth.deny(ip[0])
                zmq_socket = context.socket(zmq.PUSH)
                zmq_socket.zap_domain = b'global'
                zmq_socket.bind("tcp://" + socket_ip + ":%s" % port)
            time.sleep(1)
            print("sleeping...done")
            i += 1
    finally:
        auth.stop()
Example #16
class ContextHandler():
	def __init__(self, publicPath):
		self.__context = zmq.Context()
		self.publicPath = publicPath

		self.auth = ThreadAuthenticator(self.__context)
		self.auth.start()
		self.auth.configure_curve(domain='*', location=self.publicPath)
		self.auth.thread.setName("CurveAuth")

	def getContext(self):
		return self.__context

	def configureAuth(self):
		self.auth.configure_curve(domain='*', location=self.publicPath)

	def cleanup(self):
		self.__context.destroy()
Example #17
    def auth_init():
        """Start an authenticator for this context."""
        from zmq.auth.thread import ThreadAuthenticator
        from jomiel.log import lg

        auth = ThreadAuthenticator(ctx, log=lg())
        auth.start()
        auth.allow(opts.curve_allow)

        # Tell the authenticator to use the client certificates in the
        # specified directory.
        #
        from os.path import abspath

        pubdir = abspath(opts.curve_public_key_dir)
        auth.configure_curve(domain=opts.curve_domain, location=pubdir)

        return auth
Example #18
class CurveAuthenticator(object):
    def __init__(self,
                 ctx,
                 domain='*',
                 location=zmq.auth.CURVE_ALLOW_ANY,
                 callback=None):

        self._domain = domain
        self._location = location
        self._callback = callback
        self._ctx = ctx
        self._atx = ThreadAuthenticator(self.ctx)
        self._atx.start()
        if (self._callback is not None):
            logging.info('Callback: {0}'.format(self._callback))
            self._atx.configure_curve_callback(
                '*', credentials_provider=self._callback)
        elif (self._location == zmq.auth.CURVE_ALLOW_ANY
              or self._location is None):
            self._atx.configure_curve(domain='*',
                                      location=zmq.auth.CURVE_ALLOW_ANY)
        else:
            self.load_certs()

    @property
    def atx(self):
        return self._atx

    @property
    def location(self):
        return self._location

    @property
    def domain(self):
        return self._domain

    @property
    def ctx(self):
        return self._ctx

    def load_certs(self):
        self.atx.configure_curve(domain=self._domain, location=self._location)
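
configure_curve_callback() as used above expects a credentials provider object exposing a callback(domain, key) method. A minimal sketch of such a provider (the allow-list contents are placeholders):

class AllowListProvider(object):
    """Accepts only public keys registered up front; the authenticator
    thread calls callback() for every CURVE handshake."""

    def __init__(self, allowed_keys):
        self.allowed_keys = set(allowed_keys)

    def callback(self, domain, key):
        # 'key' is the client's public key as delivered by the ZAP request.
        return key in self.allowed_keys

# Hypothetical wiring with the CurveAuthenticator class above:
# atx = CurveAuthenticator(zmq.Context.instance(),
#                          callback=AllowListProvider([client_public_key]))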
Example #19
    def _start_thread_auth(self, socket):
        """
        Start the zmq curve thread authenticator.

        :param socket: The socket in which to configure the authenticator.
        :type socket: zmq.Socket
        """
        authenticator = ThreadAuthenticator(self._factory.context)

        # Temporary fix until we understand what the problem is
        # See https://leap.se/code/issues/7536
        time.sleep(0.5)

        authenticator.start()
        # XXX do not hardcode this here.
        authenticator.allow('127.0.0.1')
        # tell authenticator to use the certificate in a directory
        public_keys_dir = os.path.join(self._config_prefix, PUBLIC_KEYS_PREFIX)
        authenticator.configure_curve(domain="*", location=public_keys_dir)
        socket.curve_server = True  # must come before bind
Example #20
    def _init_zmq(self):
        """
        Configure the zmq components and connection.
        """
        context = zmq.Context()
        socket = context.socket(zmq.REP)

        # Start an authenticator for this context.
        auth = ThreadAuthenticator(context)
        auth.start()
        auth.allow('127.0.0.1')

        # Tell authenticator to use the certificate in a directory
        auth.configure_curve(domain='*', location=zmq.auth.CURVE_ALLOW_ANY)
        public, secret = get_backend_certificates()
        socket.curve_publickey = public
        socket.curve_secretkey = secret
        socket.curve_server = True  # must come before bind

        socket.bind(self.BIND_ADDR)

        self._zmq_socket = socket
Example #21
def setup_auth():
    global _auth
    assert _options is not None
    auth = _options.get('auth',None)
    if auth is None:
        return
    base_dir = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)),'..'))
    try:
        _auth = ThreadAuthenticator(_zctx)
        _auth.start()
        whitelist = auth.get('whitelist',None)
        if whitelist is not None:
            _auth.allow(whitelist)
        public_path = auth.get('public_key_dir','public_keys')
        _auth.configure_curve(domain='*',location=getExistsPath(base_dir,public_path))
        private_dir = getExistsPath(base_dir,auth.get('private_key_dir','private_keys'))
        private_key = os.path.join(private_dir,auth.get('private_key_file','server.key_secret'))
        server_public,server_private = zmq.auth.load_certificate(private_key)
        _sock.curve_secretkey = server_private
        _sock.curve_publickey = server_public
        _sock.curve_server = True
    except:
        _auth.stop()
        _auth = None
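
For reference, the 'auth' options consumed by setup_auth() above would have roughly this shape (inferred from the .get() calls; all values are placeholders):

_options = {
    "auth": {
        "whitelist": "192.0.2.10",            # optional; passed to _auth.allow()
        "public_key_dir": "public_keys",      # client certificates for configure_curve()
        "private_key_dir": "private_keys",
        "private_key_file": "server.key_secret",
    },
}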
Example #22
class dataTransfer():
    def __init__(self,
                 connectionType,
                 signalHost=None,
                 useLog=False,
                 context=None):

        if useLog:
            self.log = logging.getLogger("dataTransferAPI")
        elif useLog == None:
            self.log = noLoggingFunction()
        else:
            self.log = loggingFunction()

        # ZMQ applications always start by creating a context,
        # and then using that for creating sockets
        # (source: ZeroMQ, Messaging for Many Applications by Pieter Hintjens)
        if context:
            self.context = context
            self.extContext = True
        else:
            self.context = zmq.Context()
            self.extContext = False

        self.signalHost = signalHost
        self.signalPort = "50000"
        self.requestPort = "50001"
        self.dataHost = None
        self.dataPort = None

        self.signalSocket = None
        self.dataSocket = None
        self.requestSocket = None

        self.poller = zmq.Poller()

        self.auth = None

        self.targets = None

        self.supportedConnections = [
            "stream", "streamMetadata", "queryNext", "queryMetadata"
        ]

        self.signalExchanged = None

        self.streamStarted = None
        self.queryNextStarted = None

        self.socketResponseTimeout = 1000

        if connectionType in self.supportedConnections:
            self.connectionType = connectionType
        else:
            raise NotSupported("Chosen type of connection is not supported.")

    # targets: [host, port, prio] or [[host, port, prio], ...]
    def initiate(self, targets):

        if type(targets) != list:
            self.stop()
            raise FormatError("Argument 'targets' must be list.")

        if not self.context:
            self.context = zmq.Context()
            self.extContext = False

        signal = None
        # Signal exchange
        if self.connectionType == "stream":
            signalPort = self.signalPort
            signal = "START_STREAM"
        elif self.connectionType == "streamMetadata":
            signalPort = self.signalPort
            signal = "START_STREAM_METADATA"
        elif self.connectionType == "queryNext":
            signalPort = self.signalPort
            signal = "START_QUERY_NEXT"
        elif self.connectionType == "queryMetadata":
            signalPort = self.signalPort
            signal = "START_QUERY_METADATA"

        self.log.debug("Create socket for signal exchange...")

        if self.signalHost:
            self.__createSignalSocket(signalPort)
        else:
            self.stop()
            raise ConnectionFailed("No host to send signal to specified.")

        self.__setTargets(targets)

        message = self.__sendSignal(signal)

        if message and message == "VERSION_CONFLICT":
            self.stop()
            raise VersionError("Versions are conflicting.")

        elif message and message == "NO_VALID_HOST":
            self.stop()
            raise AuthenticationFailed("Host is not allowed to connect.")

        elif message and message == "CONNECTION_ALREADY_OPEN":
            self.stop()
            raise CommunicationFailed("Connection is already open.")

        elif message and message == "NO_VALID_SIGNAL":
            self.stop()
            raise CommunicationFailed(
                "Connection type is not supported for this kind of sender.")

        # if there was no response or the response was of the wrong format, the receiver should be shut down
        elif message and message.startswith(signal):
            self.log.info("Received confirmation ...")
            self.signalExchanged = signal

        else:
            raise CommunicationFailed("Sending start signal ...failed.")

    def __createSignalSocket(self, signalPort):

        # To send a notification that a Displayer is up and running, a communication socket is needed
        # create socket to exchange signals with Sender
        self.signalSocket = self.context.socket(zmq.REQ)

        # time to wait for the sender to give a confirmation of the signal
        #        self.signalSocket.RCVTIMEO = self.socketResponseTimeout
        connectionStr = "tcp://" + str(self.signalHost) + ":" + str(signalPort)
        try:
            self.signalSocket.connect(connectionStr)
            self.log.info("signalSocket started (connect) for '" +
                          connectionStr + "'")
        except:
            self.log.error("Failed to start signalSocket (connect): '" +
                           connectionStr + "'")
            raise

        # using a Poller to implement the signalSocket timeout (in older ZMQ version there is no option RCVTIMEO)
        self.poller.register(self.signalSocket, zmq.POLLIN)

    def __setTargets(self, targets):
        self.targets = []

        # [host, port, prio]
        if len(targets) == 3 and type(targets[0]) != list and type(
                targets[1]) != list and type(targets[2]) != list:
            host, port, prio = targets
            self.targets = [[host + ":" + port, prio, [""]]]

        # [host, port, prio, suffixes]
        elif len(targets) == 4 and type(targets[0]) != list and type(
                targets[1]) != list and type(targets[2]) != list and type(
                    targets[3]) == list:
            host, port, prio, suffixes = targets
            self.targets = [[host + ":" + port, prio, suffixes]]

        # [[host, port, prio], ...] or [[host, port, prio, suffixes], ...]
        else:
            for t in targets:
                if type(t) == list and len(t) == 3:
                    host, port, prio = t
                    self.targets.append([host + ":" + port, prio, [""]])
                elif type(t) == list and len(t) == 4 and type(t[3]):
                    host, port, prio, suffixes = t
                    self.targets.append([host + ":" + port, prio, suffixes])
                else:
                    self.stop()
                    self.log.debug("targets=" + str(targets))
                    raise FormatError("Argument 'targets' is of wrong format.")

    def __sendSignal(self, signal):

        if not signal:
            return

        # Send the signal that the communication infrastructure should be established
        self.log.info("Sending Signal")

        sendMessage = [__version__, signal]

        trg = cPickle.dumps(self.targets)
        sendMessage.append(trg)

        #        sendMessage = [__version__, signal, self.dataHost, self.dataPort]

        self.log.debug("Signal: " + str(sendMessage))
        try:
            self.signalSocket.send_multipart(sendMessage)
        except:
            self.log.error("Could not send signal")
            raise

        message = None
        try:
            socks = dict(self.poller.poll(self.socketResponseTimeout))
        except:
            self.log.error("Could not poll for new message")
            raise

        # if there was a response
        if self.signalSocket in socks and socks[
                self.signalSocket] == zmq.POLLIN:
            try:
                #  Get the reply.
                message = self.signalSocket.recv()
                self.log.info("Received answer to signal: " + str(message))

            except:
                self.log.error("Could not receive answer to signal")
                raise

        return message

    def start(self, dataSocket=False, whitelist=None):

        # Receive data only from whitelisted nodes
        if whitelist:
            if type(whitelist) == list:
                self.auth = ThreadAuthenticator(self.context)
                self.auth.start()
                for host in whitelist:
                    try:
                        if host == "localhost":
                            ip = [socket.gethostbyname(host)]
                        else:
                            hostname, tmp, ip = socket.gethostbyaddr(host)

                        self.log.debug("Allowing host " + host + " (" +
                                       str(ip[0]) + ")")
                        self.auth.allow(ip[0])
                    except:
                        self.log.error("Error was: ", exc_info=True)
                        raise AuthenticationFailed(
                            "Could not get IP of host " + host)
            else:
                raise FormatError("Whitelist has to be a list of IPs")

        socketIdToConnect = self.streamStarted or self.queryNextStarted

        if socketIdToConnect:
            self.log.info("Reopening already started connection.")
        else:

            ip = "0.0.0.0"  #TODO use IP of hostname?

            host = ""
            port = ""

            if dataSocket:
                if type(dataSocket) == list:
                    socketIdToConnect = dataSocket[0] + ":" + dataSocket[1]
                    host = dataSocket[0]
                    ip = socket.gethostbyaddr(host)[2][0]
                    port = dataSocket[1]
                else:
                    port = str(dataSocket)

                    host = socket.gethostname()
                    socketId = host + ":" + port
                    ipFromHost = socket.gethostbyaddr(host)[2]
                    if len(ipFromHost) == 1:
                        ip = ipFromHost[0]

            elif len(self.targets) == 1:
                host, port = self.targets[0][0].split(":")
                ipFromHost = socket.gethostbyaddr(host)[2]
                if len(ipFromHost) == 1:
                    ip = ipFromHost[0]

            else:
                raise FormatError(
                    "Multipe possible ports. Please choose which one to use.")

            socketId = host + ":" + port
            socketIdToConnect = ip + ":" + port
#            socketIdToConnect = "[" + ip + "]:" + port

        self.dataSocket = self.context.socket(zmq.PULL)
        # An additional socket is needed to establish the data retrieving mechanism
        connectionStr = "tcp://" + socketIdToConnect
        if whitelist:
            self.dataSocket.zap_domain = b'global'

        try:
            #            self.dataSocket.ipv6 = True
            self.dataSocket.bind(connectionStr)
            #            self.dataSocket.bind("tcp://[2003:ce:5bc0:a600:fa16:54ff:fef4:9fc0]:50102")
            self.log.info("Data socket of type " + self.connectionType +
                          " started (bind) for '" + connectionStr + "'")
        except:
            self.log.error("Failed to start Socket of type " +
                           self.connectionType + " (bind): '" + connectionStr +
                           "'",
                           exc_info=True)
            raise

        self.poller.register(self.dataSocket, zmq.POLLIN)

        if self.connectionType in ["queryNext", "queryMetadata"]:

            self.requestSocket = self.context.socket(zmq.PUSH)
            # An additional socket is needed to establish the data retrieving mechanism
            connectionStr = "tcp://" + self.signalHost + ":" + self.requestPort
            try:
                self.requestSocket.connect(connectionStr)
                self.log.info("Request socket started (connect) for '" +
                              connectionStr + "'")
            except:
                self.log.error("Failed to start Socket of type " +
                               self.connectionType + " (connect): '" +
                               connectionStr + "'",
                               exc_info=True)
                raise

            self.queryNextStarted = socketId
        else:
            self.streamStarted = socketId

    ##
    #
    # Receives or queries for new files depending on the connection initialized
    #
    # returns either
    #   the newest file
    #       (if connection type "queryNext" or "stream" was chosen)
    #   the path of the newest file
    #       (if connection type "queryMetadata" or "streamMetadata" was chosen)
    #
    ##
    def get(self, timeout=None):

        if not self.streamStarted and not self.queryNextStarted:
            self.log.info(
                "Could not communicate, no connection was initialized.")
            return None, None

        if self.queryNextStarted:

            sendMessage = ["NEXT", self.queryNextStarted]
            try:
                self.requestSocket.send_multipart(sendMessage)
            except Exception as e:
                self.log.error("Could not send request to requestSocket",
                               exc_info=True)
                return None, None

        while True:
            # receive data
            if timeout:
                try:
                    socks = dict(self.poller.poll(timeout))
                except:
                    self.log.error("Could not poll for new message")
                    raise
            else:
                try:
                    socks = dict(self.poller.poll())
                except:
                    self.log.error("Could not poll for new message")
                    raise

            # if there was a response
            if self.dataSocket in socks and socks[
                    self.dataSocket] == zmq.POLLIN:

                try:
                    multipartMessage = self.dataSocket.recv_multipart()
                except:
                    self.log.error("Receiving data..failed.", exc_info=True)
                    return [None, None]

                if multipartMessage[0] == b"ALIVE_TEST":
                    continue
                elif len(multipartMessage) < 2:
                    self.log.error(
                        "Received mutipart-message is too short. Either config or file content is missing."
                    )
                    self.log.debug("multipartMessage=" +
                                   str(mutipartMessage)[:100])
                    return [None, None]

                # extract multipart message
                try:
                    metadata = cPickle.loads(multipartMessage[0])
                except:
                    self.log.error(
                        "Could not extract metadata from the multipart-message.",
                        exc_info=True)
                    metadata = None

                #TODO validate multipartMessage (like correct dict-values for metadata)

                try:
                    payload = multipartMessage[1]
                except:
                    self.log.warning(
                        "An empty file was received within the multipart-message",
                        exc_info=True)
                    payload = None

                return [metadata, payload]
            else:
                self.log.warning("Could not receive data in the given time.")

                if self.queryNextStarted:
                    try:
                        self.requestSocket.send_multipart(
                            ["CANCEL", self.queryNextStarted])
                    except Exception as e:
                        self.log.error("Could not cancel the next query",
                                       exc_info=True)

                return [None, None]

    def store(self, targetBasePath, dataObject):

        if type(dataObject) is not list or len(dataObject) != 2:
            raise FormatError("Wrong input type for 'store'")

        payloadMetadata = dataObject[0]
        payload = dataObject[1]

        if type(payloadMetadata) is not dict:
            raise FormatError("payload: Wrong input format in 'store'")

        #save all chunks to file
        while True:

            #TODO check if payload != cPickle.dumps(None) ?
            if payloadMetadata and payload:
                #append to file
                try:
                    self.log.debug(
                        "append to file based on multipart-message...")
                    #TODO: save message to file using a thread (avoids blocking)
                    #TODO: instead of open/close file for each chunk recycle the file-descriptor for all chunks opened
                    self.__appendChunksToFile(targetBasePath, payloadMetadata,
                                              payload)
                    self.log.debug(
                        "append to file based on multipart-message...success.")
                except KeyboardInterrupt:
                    self.log.info(
                        "KeyboardInterrupt detected. Unable to append multipart-content to file."
                    )
                    break
                except Exception, e:
                    self.log.error(
                        "Unable to append multipart-content to file.",
                        exc_info=True)
                    self.log.debug(
                        "Append to file based on multipart-message...failed.")

                if len(payload) < payloadMetadata["chunkSize"]:
                    #indicated end of file. Leave loop
                    filename = self.generateTargetFilepath(
                        targetBasePath, payloadMetadata)
                    fileModTime = payloadMetadata["fileModTime"]

                    self.log.info("New file with modification time " +
                                  str(fileModTime) + " received and saved: " +
                                  str(filename))
                    break

            try:
                [payloadMetadata, payload] = self.get()
            except:
                self.log.error("Getting data failed.", exc_info=True)
                break
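
For orientation, a hypothetical call sequence for the dataTransfer class above (host names, port and whitelist are placeholders; stop() belongs to the original API but is not part of this excerpt):

# Ports are passed as strings, as expected by __setTargets above.
query = dataTransfer("queryNext", signalHost="sender.example.net")
query.initiate(["receiver.example.net", "50101", 1])   # [host, port, prio]
query.start(dataSocket="50101", whitelist=["sender.example.net"])
try:
    metadata, payload = query.get(timeout=2000)
finally:
    query.stop()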
Example #23
class MultiNodeAgent(BEMOSSAgent):
    def __init__(self, *args, **kwargs):
        super(MultiNodeAgent, self).__init__(*args, **kwargs)
        self.multinode_status = dict()
        self.getMultinodeData()
        self.agent_id = 'multinodeagent'
        self.is_parent = False
        self.last_sync_with_parent = datetime(1991, 1, 1)  # equivalent to -ve infinity
        self.parent_node = None
        self.recently_online_node_list = []  # initialize to lists to empty
        self.recently_offline_node_list = [
        ]  # they will be filled as nodes are discovered to be online/offline
        self.setup()

        self.runPeriodically(self.send_heartbeat, 20)
        self.runPeriodically(self.check_health, 60, start_immediately=False)
        self.runPeriodically(self.sync_all_with_parent, 600)
        self.subscribe('relay_message', self.relayDirectMessage)
        self.subscribe('update_multinode_data', self.updateMultinodeData)
        self.runContinuously(self.pollClients)
        self.run()

    def getMultinodeData(self):
        self.multinode_data = db_helper.get_multinode_data()

        self.nodelist_dict = {
            node['name']: node
            for node in self.multinode_data['known_nodes']
        }
        self.node_name_list = [
            node['name'] for node in self.multinode_data['known_nodes']
        ]
        self.address_list = [
            node['address'] for node in self.multinode_data['known_nodes']
        ]
        self.server_key_list = [
            node['server_key'] for node in self.multinode_data['known_nodes']
        ]
        self.node_name = self.multinode_data['this_node']

        for index, node in enumerate(self.multinode_data['known_nodes']):
            if node['name'] == self.node_name:
                self.node_index = index
                break
        else:
            raise ValueError(
                '"this_node:" entry on the multinode_data json file is invalid'
            )

        for node_name in self.node_name_list:  #initialize all nodes data
            if node_name not in self.multinode_status:  # initialize new nodes;
                # the node may already exist if getMultinodeData is called again later
                self.multinode_status[node_name] = dict()
                self.multinode_status[node_name][
                    'health'] = -10  #initialized; never online/offline
                self.multinode_status[node_name]['last_sync_time'] = datetime(
                    1991, 1, 1)
                self.multinode_status[node_name]['last_online_time'] = None
                self.multinode_status[node_name]['last_offline_time'] = None
                self.multinode_status[node_name]['last_scanned_time'] = None

    def setup(self):
        print "Setup"

        base_dir = settings.PROJECT_DIR + "/"
        public_keys_dir = os.path.abspath(os.path.join(base_dir,
                                                       'public_keys'))
        secret_keys_dir = os.path.abspath(
            os.path.join(base_dir, 'private_keys'))

        self.secret_keys_dir = secret_keys_dir
        self.public_keys_dir = public_keys_dir

        if not (os.path.exists(public_keys_dir)
                and os.path.exists(secret_keys_dir)):
            logging.critical(
                "Certificates are missing - run generate_certificates.py script first"
            )
            sys.exit(1)

        ctx = zmq.Context.instance()
        self.ctx = ctx
        # Start an authenticator for this context.
        self.auth = ThreadAuthenticator(ctx)
        self.auth.start()
        self.configure_authenticator()

        server = ctx.socket(zmq.PUB)

        server_secret_key_filename = self.multinode_data['known_nodes'][
            self.node_index]['server_secret_key']
        server_secret_file = os.path.join(secret_keys_dir,
                                          server_secret_key_filename)
        server_public, server_secret = zmq.auth.load_certificate(
            server_secret_file)
        server.curve_secretkey = server_secret
        server.curve_publickey = server_public
        server.curve_server = True  # must come before bind
        server.bind(
            self.multinode_data['known_nodes'][self.node_index]['address'])
        self.server = server
        self.configureClient()

    def configure_authenticator(self):
        self.auth.allow()
        # Tell authenticator to use the certificate in a directory
        self.auth.configure_curve(domain='*', location=self.public_keys_dir)

    def disperseMessage(self, sender, topic, message):
        for node_name in self.node_name_list:
            if node_name == self.node_name:
                continue
            self.server.send(
                jsonify(sender, node_name + '/republish/' + topic, message))

    def republishToParent(self, sender, topic, message):
        if self.is_parent:
            return  #if I am parent, the message is already published
        for node_name in self.node_name_list:
            if self.multinode_status[node_name][
                    'health'] == 2:  #health = 2 is the parent node
                self.server.send(
                    jsonify(sender, node_name + '/republish/' + topic,
                            message))

    def sync_node_with_parent(self, node_name):
        if self.is_parent:
            print "Syncing " + node_name
            self.last_sync_with_parent = datetime.now()
            sync_date_string = self.last_sync_with_parent.strftime(
                '%B %d, %Y, %H:%M:%S')
            # os.system('pg_dump bemossdb -f ' + self.self_database_dump_path)
            # with open(self.self_database_dump_path, 'r') as f:
            #     file_content = f.read()
            # msg = {'database_dump': base64.b64encode(file_content)}
            self.server.send(
                jsonify(
                    self.node_name, node_name + '/sync-with-parent/' +
                    sync_date_string + '/' + self.node_name, ""))

    def sync_all_with_parent(self, dbcon):

        if self.is_parent:
            self.last_sync_with_parent = datetime.now()
            sync_date_string = self.last_sync_with_parent.strftime(
                '%B %d, %Y, %H:%M:%S')
            print "Syncing all nodes"
            for node_name in self.node_name_list:
                if node_name == self.node_name:
                    continue
                # os.system('pg_dump bemossdb -f ' + self.self_database_dump_path)
                # with open(self.self_database_dump_path, 'r') as f:
                #     file_content = f.read()
                # msg = {'database_dump': base64.b64encode(file_content)}
                self.server.send(
                    jsonify(
                        self.node_name, node_name + '/sync-with-parent/' +
                        sync_date_string + '/' + self.node_name, ""))

    def send_heartbeat(self, dbcon):
        #self.vip.pubsub.publish('pubsub', 'listener', None, {'message': 'Hello Listener'})
        #print 'publishing'
        print "Sending heartbeat"

        last_sync_string = self.last_sync_with_parent.strftime(
            '%B %d, %Y, %H:%M:%S')
        self.server.send(
            jsonify(
                self.node_name, 'heartbeat/' + self.node_name + '/' +
                str(self.is_parent) + '/' + last_sync_string, ""))

    def extract_ip(self, addr):
        return re.search(r'([0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3})',
                         addr).groups()[0]

    def getNodeId(self, node_name):

        for index, node in enumerate(self.multinode_data['known_nodes']):
            if node['name'] == node_name:
                node_index = index
                break
        else:
            raise ValueError('the node name: ' + node_name +
                             ' is not found in multinode data')

        return node_index

    def getNodeName(self, node_id):
        return self.multinode_data['known_nodes'][node_id]['name']

    def handle_offline_nodes(self, dbcon, node_name_list):
        if self.is_parent:
            # start all the agents belonging to that node on this node
            command_group = []
            for node_name in node_name_list:
                node_id = self.getNodeId(node_name)
                #put the offline event into cassandra events log table, and also create notification
                self.EventRegister(dbcon,
                                   'node-offline',
                                   reason='communication-error',
                                   source=node_name)
                # get a list of agents that were supposedly running in that offline node
                dbcon.execute(
                    "SELECT agent_id FROM " + node_devices_table +
                    " WHERE assigned_node_id=%s", (node_id, ))

                if dbcon.rowcount:
                    agent_ids = dbcon.fetchall()

                    for agent_id in agent_ids:
                        message = dict()
                        message[STATUS_CHANGE.AGENT_ID] = agent_id[0]
                        message[STATUS_CHANGE.NODE] = str(self.node_index)
                        message[STATUS_CHANGE.AGENT_STATUS] = 'start'
                        message[
                            STATUS_CHANGE.
                            NODE_ASSIGNMENT_TYPE] = ZONE_ASSIGNMENT_TYPES.TEMPORARY
                        command_group += [message]
                        dbcon.execute(
                            "UPDATE " + node_devices_table +
                            " SET current_node_id=(%s), date_move=(%s)"
                            " WHERE agent_id=(%s)",
                            (self.node_index, datetime.now(
                                pytz.UTC), agent_id[0]))
                        dbcon.commit()
            print "moving agents from offline node to parent: " + str(
                node_name_list)
            print command_group
            if command_group:
                self.bemoss_publish(target='networkagent',
                                    topic='status_change',
                                    message=command_group)

    def handle_online_nodes(self, dbcon, node_name_list):
        if self.is_parent:
            # start all the agents belonging to that nodes back on them
            command_group = []
            for node_name in node_name_list:

                node_id = self.getNodeId(node_name)
                if self.node_index == node_id:
                    continue  #don't handle self-online
                self.EventRegister(dbcon,
                                   'node-online',
                                   reason='communication-restored',
                                   source=node_name)

                #get a list of agents that were supposed to be running in that online node
                dbcon.execute(
                    "SELECT agent_id FROM " + node_devices_table +
                    " WHERE assigned_node_id=%s", (node_id, ))
                if dbcon.rowcount:
                    agent_ids = dbcon.fetchall()
                    for agent_id in agent_ids:
                        message = dict()
                        message[STATUS_CHANGE.AGENT_ID] = agent_id[0]
                        message[
                            STATUS_CHANGE.
                            NODE_ASSIGNMENT_TYPE] = ZONE_ASSIGNMENT_TYPES.PERMANENT
                        message[STATUS_CHANGE.NODE] = str(self.node_index)
                        message[STATUS_CHANGE.
                                AGENT_STATUS] = 'stop'  #stop in this node
                        command_group += [message]
                        message = dict(message)  #create another copy
                        message[STATUS_CHANGE.NODE] = str(node_id)
                        message[
                            STATUS_CHANGE.
                            AGENT_STATUS] = 'start'  #start in the target node
                        command_group += [message]
                        #immediately update the multnode device assignment table
                        dbcon.execute(
                            "UPDATE " + node_devices_table +
                            " SET current_node_id=(%s), date_move=(%s)"
                            " WHERE agent_id=(%s)",
                            (node_id, datetime.now(pytz.UTC), agent_id[0]))
                        dbcon.commit()

            print "Moving agents back to the online node: " + str(
                node_name_list)
            print command_group

            if command_group:
                self.bemoss_publish(target='networkagent',
                                    topic='status_change',
                                    message=command_group)

    def updateParent(self, dbcon, parent_node_name):
        parent_ip = self.extract_ip(
            self.nodelist_dict[parent_node_name]['address'])
        write_new = False
        if not os.path.isfile(settings.MULTINODE_PARENT_IP_FILE
                              ):  # but parent file doesn't exists
            write_new = True
        else:
            with open(settings.MULTINODE_PARENT_IP_FILE, 'r') as f:
                read_ip = f.read()
            if read_ip != parent_ip:
                write_new = True
        if write_new:
            with open(settings.MULTINODE_PARENT_IP_FILE, 'w') as f:
                f.write(parent_ip)
            if dbcon:
                dbcon.close()  #close old connection
            dbcon = db_helper.db_connection(
            )  #start new connection using new parent_ip
            self.updateMultinodeData(sender=self.name,
                                     topic='update_parent',
                                     message="")

    def check_health(self, dbcon):

        for node_name, node in self.multinode_status.items():
            if node['health'] > 0:  #initialize all online nodes to 0. If they are really online, they should change it
                #  back to 1 or 2 (parent) within 30 seconds throught the heartbeat.
                node['health'] = 0

        time.sleep(30)
        parent_node_name = None  #initialize parent node
        online_node_exists = False
        for node_name, node in self.multinode_status.items():
            node['last_scanned_time'] = datetime.now()
            if node['health'] == 0:
                node['health'] = -1
                node['last_offline_time'] = datetime.now()
                self.recently_offline_node_list += [node_name]
            elif node['health'] == -1:  #offline since long
                pass
            elif node[
                    'health'] == -10:  #The node was initialized to -10, and never came online. Treat it as recently going
                # offline for this iteration so that the agents that were supposed to be running there can be migrated
                node['health'] = -1
                self.recently_offline_node_list += [node_name]
            elif node['health'] == 2:  #there is some parent node present
                parent_node_name = node_name
            if node['health'] > 0:
                online_node_exists = True  #At-least one node (itself) should be online, if not some problem

        assert online_node_exists, "At least one node (current node) must be online"

        if not parent_node_name:  #parent node doesn't exist
            #find a suitable node to elect a parent. The node with latest update from previous parent wins. If there is
            #tie, then the node coming earlier in the node-list on multinode data wins

            online_node_last_sync = dict(
            )  #only the online nodes, and their last-sync-times
            for node_name, node in self.multinode_status.items(
            ):  #copy only the online nodes
                if node['health'] > 0:
                    online_node_last_sync[node_name] = node['last_sync_time']

            latest_node = max(online_node_last_sync,
                              key=online_node_last_sync.get)
            latest_sync_date = online_node_last_sync[latest_node]

            for node_name in self.node_name_list:
                if self.multinode_status[node_name][
                        'health'] <= 0:  #dead nodes can't be parents
                    continue
                if self.multinode_status[node_name][
                        'last_sync_time'] == latest_sync_date:  # this is the first node with the latest update from parent
                    #elligible parent found
                    self.updateParent(dbcon, node_name)

                    if node_name == self.node_name:  # I am the node, so I get to become the parent
                        self.is_parent = True
                        print "I am the boss now, " + self.node_name
                        break
                    else:  #I-am-not-the-first-node with latest update; somebody else is
                        self.is_parent = False
                        break
        else:  #parent node exist
            self.updateParent(dbcon, parent_node_name)

        for node in self.multinode_data['known_nodes']:
            print node['name'] + ': ' + str(
                self.multinode_status[node['name']]['health'])

        if self.is_parent:
            #if this is a parent node, update the node_info table
            if dbcon is None:  #if no database connection exists make connection
                dbcon = db_helper.db_connection()

            tbl_node_info = settings.DATABASES['default']['TABLE_node_info']
            dbcon.execute('select node_id from ' + tbl_node_info)
            to_be_deleted_node_ids = dbcon.fetchall()
            for index, node in enumerate(self.multinode_data['known_nodes']):
                if (index,) in to_be_deleted_node_ids:
                    to_be_deleted_node_ids.remove((index,))  # this node still exists, so don't delete it
                result = dbcon.execute(
                    'select * from ' + tbl_node_info + ' where node_id=%s', (index,))
                node_type = 'parent' if self.multinode_status[node['name']]['health'] == 2 else 'child'
                node_status = 'ONLINE' if self.multinode_status[node['name']]['health'] > 0 else 'OFFLINE'
                ip_address = self.extract_ip(node['address'])
                last_scanned_time = self.multinode_status[node['name']]['last_online_time']
                last_offline_time = self.multinode_status[node['name']]['last_offline_time']
                last_sync_time = self.multinode_status[node['name']]['last_sync_time']

                var_list = "(node_id,node_name,node_type,node_status,ip_address,last_scanned_time,last_offline_time,last_sync_time)"
                value_placeholder_list = "(%s,%s,%s,%s,%s,%s,%s,%s)"
                actual_values_list = (index, node['name'], node_type,
                                      node_status, ip_address,
                                      last_scanned_time, last_offline_time,
                                      last_sync_time)

                if dbcon.rowcount == 0:
                    dbcon.execute(
                        "insert into " + tbl_node_info + " " + var_list +
                        " VALUES" + value_placeholder_list, actual_values_list)
                else:
                    dbcon.execute(
                        "update " + tbl_node_info + " SET " + var_list +
                        " = " + value_placeholder_list + " where node_id = %s",
                        actual_values_list + (index, ))
            dbcon.commit()

            for id in to_be_deleted_node_ids:
                dbcon.execute(
                    'delete from accounts_userprofile_nodes where nodeinfo_id=%s',
                    id)  #delete entries in user-profile for the old node
                dbcon.commit()
                dbcon.execute('delete from ' + tbl_node_info +
                              ' where node_id=%s', id)  #delete the old nodes
                dbcon.commit()

            if self.recently_online_node_list:
                # Online nodes must be handled first because the same node can appear on both
                # recently_online_node_list and recently_offline_node_list when it goes offline
                # shortly after coming online.
                self.handle_online_nodes(dbcon, self.recently_online_node_list)
                self.recently_online_node_list = []  # reset after handling
            if self.recently_offline_node_list:
                self.handle_offline_nodes(dbcon,
                                          self.recently_offline_node_list)
                self.recently_offline_node_list = []  # reset after handling

    def connect_client(self, node):
        server_public_file = os.path.join(self.public_keys_dir,
                                          node['server_key'])
        server_public, _ = zmq.auth.load_certificate(server_public_file)
        # The client must know the server's public key to make a CURVE connection.
        self.client.curve_serverkey = server_public
        self.client.setsockopt(zmq.SUBSCRIBE, 'heartbeat/')
        self.client.setsockopt(zmq.SUBSCRIBE, self.node_name)
        self.client.connect(node['address'])

    def disconnect_client(self, node):
        self.client.disconnect(node['address'])

    def configureClient(self):
        print "Starting to receive Heart-beat"
        client = self.ctx.socket(zmq.SUB)
        # We need two certificates, one for the client and one for
        # the server. The client must know the server's public key
        # to make a CURVE connection.

        client_secret_key_filename = self.multinode_data['known_nodes'][
            self.node_index]['client_secret_key']
        client_secret_file = os.path.join(self.secret_keys_dir,
                                          client_secret_key_filename)
        client_public, client_secret = zmq.auth.load_certificate(
            client_secret_file)
        client.curve_secretkey = client_secret
        client.curve_publickey = client_public

        self.client = client

        for node in self.multinode_data['known_nodes']:
            self.connect_client(node)

    def pollClients(self, dbcon):
        if self.client.poll(1000):
            sender, topic, msg = dejsonify(self.client.recv())
            topic_list = topic.split('/')
            if topic_list[0] == 'heartbeat':
                node_name = sender
                is_parent = topic_list[2]
                last_sync_with_parent = topic_list[3]
                if self.multinode_status[node_name]['health'] < 0:  # health < 0 means the node was offline
                    print node_name + " is back online"
                    self.recently_online_node_list += [node_name]
                    self.sync_node_with_parent(node_name)

                if is_parent.lower() in ['false', '0']:
                    self.multinode_status[node_name]['health'] = 1
                elif is_parent.lower() in ['true', '1']:
                    self.multinode_status[node_name]['health'] = 2
                    self.parent_node = node_name
                else:
                    raise ValueError(
                        'Invalid is_parent string in heart-beat message')

                self.multinode_status[node_name]['last_online_time'] = datetime.now()
                self.multinode_status[node_name]['last_sync_time'] = datetime.strptime(
                    last_sync_with_parent, '%B %d, %Y, %H:%M:%S')

            if topic_list[0] == self.node_name:
                if topic_list[1] == 'sync-with-parent':
                    pass
                    # print topic
                    # self.last_sync_with_parent = datetime.strptime(topic_list[2], '%B %d, %Y, %H:%M:%S')
                    # content = base64.b64decode(msg['database_dump'])
                    # newpath = 'bemossdb.sql'
                    # with open(newpath, 'w') as f:
                    #     f.write(content)
                    # try:
                    #     os.system(
                    #         'psql -c "SELECT pg_terminate_backend(pg_stat_activity.pid) FROM pg_stat_activity WHERE pid <> pg_backend_pid();"')
                    #     os.system(
                    #         'dropdb bemossdb')  # This step requires all connections to be closed
                    #     os.system('createdb bemossdb -O admin')
                    #     dump_result = subprocess.check_output('psql bemossdb < ' + newpath, shell=True)
                    # except Exception as er:
                    #     print "Couldn't sync database with parent because of error: "
                    #     print er
                    #
                    # parent_node_name = topic_list[3]
                    # self.updateParent(parent_node_name)

                if topic_list[1] == 'republish':
                    target = msg['target']
                    actual_message = msg['actual_message']
                    actual_topic = msg['actual_topic']
                    self.bemoss_publish(target=target,
                                        topic=actual_topic + '/republished',
                                        message=actual_message,
                                        sender=sender)

            print self.node_name + ": " + topic, str(msg)

        else:
            time.sleep(2)

    def cleanup(self):
        # stop auth thread
        self.auth.stop()

    def updateMultinodeData(self, dbcon, sender, topic, message):
        print "Updating Multinode data"
        topic_list = topic.split('/')
        self.configure_authenticator()
        #to/multinodeagent/from/<doesn't matter>/update_multinode_data
        if topic_list[4] == 'update_multinode_data':
            old_multinode_data = self.multinode_data
            self.getMultinodeData()
            for node in self.multinode_data['known_nodes']:
                if node not in old_multinode_data['known_nodes']:
                    print "New node has been added to the cluster: " + node[
                        'name']
                    print "We will connect to it"
                    self.connect_client(node)

            for node in old_multinode_data['known_nodes']:
                if node not in self.multinode_data['known_nodes']:
                    print "Node has been removed from the cluster: " + node[
                        'name']
                    print "We will disconnect from it"
                    self.disconnect_client(node)
                    # TODO: remove it from the node_info table

        print "yay! got it"

    def relayDirectMessage(self, dbcon, sender, topic, message):
        print topic
        #to/<some_agent_or_ui>/topic/from/<some_agent_or_ui>

        from_entity = sender
        target = message['target']
        actual_message = message['actual_message']
        actual_topic = message['actual_topic']

        for to_entity in target:
            if to_entity in settings.NO_FORWARD_AGENTS:
                return  #no forwarding should be done for these agents
            elif to_entity in settings.PARENT_NODE_SYSTEM_AGENTS:
                if not self.is_parent:
                    self.republishToParent(sender, topic, message)
            elif to_entity == "ALL":
                self.disperseMessage(sender, topic=topic, message=message)
            else:
                dbcon.execute(
                    "SELECT current_node_id FROM " + node_devices_table +
                    " WHERE agent_id=%s", (to_entity, ))
                if dbcon.rowcount:
                    node_id = dbcon.fetchone()[0]
                    if node_id != self.node_index:
                        self.server.send(
                            jsonify(
                                sender,
                                self.getNodeName(node_id) + '/republish/' +
                                topic, message))
                else:
                    self.disperseMessage(
                        sender, topic, message
                    )  #republish to all nodes if we don't know where to send
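
The connect_client()/configureClient() methods in this example load the client's own certificate and pin the server's public key before subscribing. As a reference, here is a minimal, self-contained sketch of that CURVE SUB-client pattern; the certificate paths, topic, and endpoint are placeholders, not values taken from the example above.

# Minimal sketch of a CURVE-secured SUB client (placeholder paths, topic and endpoint).
import zmq
import zmq.auth

ctx = zmq.Context.instance()
sub = ctx.socket(zmq.SUB)

# The client's own keypair; the .key_secret file contains both halves.
client_public, client_secret = zmq.auth.load_certificate("certs/client.key_secret")
sub.curve_publickey = client_public
sub.curve_secretkey = client_secret

# The client must know the server's public key to complete the CURVE handshake.
server_public, _ = zmq.auth.load_certificate("certs/server.key")
sub.curve_serverkey = server_public

sub.setsockopt(zmq.SUBSCRIBE, b"heartbeat/")
sub.connect("tcp://127.0.0.1:5556")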
예제 #24
0
def test_encryption(tmpdir):
    # Create the tmp names
    conf_filename = str(tmpdir.join("conf.yaml"))
    pull_url = tmpdir.join("input.pull.socket")
    pull_cert_dir = tmpdir.mkdir("input.pull")
    pull_clients_cert_dir = pull_cert_dir.mkdir("clients")
    sub_url = tmpdir.join("input.sub.socket")
    sub_cert_dir = tmpdir.mkdir("input.sub")
    push_url = tmpdir.join("output.push.socket")
    inbound = tmpdir.join("inbound")
    outbound = tmpdir.join("outbound")
    stdout = tmpdir.join("stdout")
    stderr = tmpdir.join("stderr")

    # Create the certificates
    create_certificates(str(pull_cert_dir), "pull")
    create_certificates(str(pull_clients_cert_dir), "client1")
    create_certificates(str(pull_clients_cert_dir), "client2")
    create_certificates(str(sub_cert_dir), "sub")
    create_certificates(str(sub_cert_dir), "sub-server")

    with open(conf_filename, "w") as f:
        f.write("inputs:\n")
        f.write("- class: ZMQPull\n")
        f.write("  name: in-pull\n")
        f.write("  options:\n")
        f.write("    url: ipc://%s\n" % pull_url)
        f.write("    encryption:\n")
        f.write("      self: %s\n" % pull_cert_dir.join("pull.key_secret"))
        f.write("      clients: %s\n" % pull_clients_cert_dir)
        f.write("- class: ZMQSub\n")
        f.write("  name: in-sub\n")
        f.write("  options:\n")
        f.write("    url: ipc://%s\n" % sub_url)
        f.write("    encryption:\n")
        f.write("      self: %s\n" % sub_cert_dir.join("sub.key_secret"))
        f.write("      server: %s\n" % sub_cert_dir.join("sub-server.key"))
        f.write("core:\n")
        f.write("  inbound: ipc://%s\n" % inbound)
        f.write("  outbound: ipc://%s\n" % outbound)
        f.write("outputs:\n")
        f.write("- class: ZMQPush\n")
        f.write("  name: out-push\n")
        f.write("  options:\n")
        f.write("    url: ipc://%s\n" % push_url)
    args = [
        "python3",
        "-m",
        "reactobus",
        "--conf",
        conf_filename,
        "--level",
        "DEBUG",
        "--log-file",
        "-",
    ]
    proc = subprocess.Popen(args,
                            stdout=open(str(stdout), "w"),
                            stderr=open(str(stderr), "w"))

    # Create the input sockets
    ctx = zmq.Context.instance()
    in_sock = ctx.socket(zmq.PUSH)
    (server_public, _) = load_certificate(str(pull_cert_dir.join("pull.key")))
    in_sock.curve_serverkey = server_public
    (client_public, client_private) = load_certificate(
        str(pull_clients_cert_dir.join("client1.key_secret")))
    in_sock.curve_publickey = client_public
    in_sock.curve_secretkey = client_private
    in_sock.connect("ipc://%s" % pull_url)

    out_sock = ctx.socket(zmq.PULL)
    out_sock.bind("ipc://%s" % push_url)

    pub_sock = ctx.socket(zmq.PUB)
    auth = ThreadAuthenticator(ctx)
    auth.start()
    auth.configure_curve(domain="*", location=str(sub_cert_dir))
    (server_public, server_secret) = load_certificate(
        str(sub_cert_dir.join("sub-server.key_secret")))
    pub_sock.curve_publickey = server_public
    pub_sock.curve_secretkey = server_secret
    pub_sock.curve_server = True
    pub_sock.bind("ipc://%s" % sub_url)

    # Allow the process some time to set up and connect
    time.sleep(1)

    # Send some data
    data = [
        b"org.videolan.git",
        b(str(uuid.uuid1())),
        b(datetime.datetime.utcnow().isoformat()),
        b("videolan-git"),
        b(
            json.dumps({
                "url": "https://code.videolan.org/éêï",
                "username": "******"
            })),
    ]
    in_sock.send_multipart(data)
    msg = out_sock.recv_multipart()
    assert msg == data

    data = [
        b"org.videolan.git",
        b(str(uuid.uuid1())),
        b(datetime.datetime.utcnow().isoformat()),
        b("videolan-git"),
        b(
            json.dumps({
                "url": "https://code.videolan.org/éêï",
                "username": "******"
            })),
    ]
    pub_sock.send_multipart(data)
    msg = out_sock.recv_multipart()
    assert msg == data

    # End the process
    proc.terminate()
    proc.wait()
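
The test above relies on create_certificates() to produce fresh keypairs before wiring up the PULL and SUB inputs. Below is a minimal sketch of generating such a pair and loading it back; the directory and key name are placeholders.

# Minimal sketch (placeholder directory and name): generate a CURVE keypair and load it back.
import os
import zmq.auth

cert_dir = "certs"
os.makedirs(cert_dir, exist_ok=True)

# Writes certs/node.key (public only) and certs/node.key_secret (public + secret).
public_file, secret_file = zmq.auth.create_certificates(cert_dir, "node")

# load_certificate() returns (public, secret); secret is None when given a .key file.
public, secret = zmq.auth.load_certificate(secret_file)
assert secret is not None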
예제 #25
0
  def run(self):
    self.set_status("Client Startup")
    self.set_status("Creating zmq Contexts",1)
    clientctx = zmq.Context() 
    self.set_status("Starting zmq ThreadedAuthenticator",1)
    #clientauth = zmq.auth.ThreadedAuthenticator(clientctx)
    clientauth = ThreadAuthenticator(clientctx)
    clientauth.start()
    
    with taco.globals.settings_lock:
      publicdir  = os.path.normpath(os.path.abspath(taco.globals.settings["TacoNET Certificates Store"] + "/"  + taco.globals.settings["Local UUID"] + "/public/"))
      privatedir = os.path.normpath(os.path.abspath(taco.globals.settings["TacoNET Certificates Store"] + "/"  + taco.globals.settings["Local UUID"] + "/private/"))

    self.set_status("Configuring Curve to use publickey dir:" + publicdir)
    clientauth.configure_curve(domain='*', location=publicdir)
    
    poller = zmq.Poller()
    while not self.stop.is_set():
      #logging.debug("PRE")
      result = self.sleep.wait(0.1)
      #logging.debug(result)
      self.sleep.clear()
      if self.stop.is_set(): break

      if abs(time.time() - self.connect_block_time) > 1:
        with taco.globals.settings_lock: self.max_upload_rate   = taco.globals.settings["Upload Limit"] * taco.constants.KB
        with taco.globals.settings_lock: self.max_download_rate = taco.globals.settings["Download Limit"] * taco.constants.KB
        self.chunk_request_rate = float(taco.constants.FILESYSTEM_CHUNK_SIZE) / float(self.max_download_rate)
        #logging.debug(str((self.max_download_rate,taco.constants.FILESYSTEM_CHUNK_SIZE,self.chunk_request_rate)))
        self.connect_block_time = time.time() 
        with taco.globals.settings_lock:
          for peer_uuid in taco.globals.settings["Peers"].keys():
            if taco.globals.settings["Peers"][peer_uuid]["enabled"]:
              #init some defaults
              if not peer_uuid in self.client_reconnect_mod: self.client_reconnect_mod[peer_uuid] = taco.constants.CLIENT_RECONNECT_MIN
              if not peer_uuid in self.client_connect_time:  self.client_connect_time[peer_uuid]  = time.time() + self.client_reconnect_mod[peer_uuid]
              if not peer_uuid in self.client_timeout:       self.client_timeout[peer_uuid]       = time.time() + taco.constants.ROLLCALL_TIMEOUT

              if time.time() >= self.client_connect_time[peer_uuid]:
                if peer_uuid not in self.clients.keys():
                  self.set_status("Starting Client for: " + peer_uuid)
                  try:
                    ip_of_client = socket.gethostbyname(taco.globals.settings["Peers"][peer_uuid]["hostname"])
                  except:
                    self.set_status("Starting of client failed due to bad dns lookup:" + peer_uuid)
                    continue
                  self.clients[peer_uuid] = clientctx.socket(zmq.DEALER)
                  self.clients[peer_uuid].setsockopt(zmq.LINGER, 0)
                  client_public, client_secret = zmq.auth.load_certificate(os.path.normpath(os.path.abspath(privatedir + "/" + taco.constants.KEY_GENERATION_PREFIX +"-client.key_secret")))
                  self.clients[peer_uuid].curve_secretkey = client_secret
                  self.clients[peer_uuid].curve_publickey = client_public
                  self.clients[peer_uuid].curve_serverkey = str(taco.globals.settings["Peers"][peer_uuid]["serverkey"])
                  self.clients[peer_uuid].connect("tcp://" + ip_of_client + ":" + str(taco.globals.settings["Peers"][peer_uuid]["port"]))
                  self.next_rollcall[peer_uuid] = time.time()

                  with taco.globals.high_priority_output_queue_lock:   taco.globals.high_priority_output_queue[peer_uuid]   = Queue.Queue()
                  with taco.globals.medium_priority_output_queue_lock: taco.globals.medium_priority_output_queue[peer_uuid] = Queue.Queue()
                  with taco.globals.low_priority_output_queue_lock:    taco.globals.low_priority_output_queue[peer_uuid]    = Queue.Queue()
                  with taco.globals.file_request_output_queue_lock:    taco.globals.file_request_output_queue[peer_uuid]    = Queue.Queue()

                  poller.register(self.clients[peer_uuid],zmq.POLLIN)

      if len(self.clients.keys()) == 0: continue

      peer_keys = self.clients.keys()
      random.shuffle(peer_keys)
      for peer_uuid in peer_keys:
        #self.set_status("Socket Write Possible:" + peer_uuid)

        #high priority queue processing
        with taco.globals.high_priority_output_queue_lock:
          while not taco.globals.high_priority_output_queue[peer_uuid].empty():
            self.set_status("high priority output q not empty:" + peer_uuid)
            data = taco.globals.high_priority_output_queue[peer_uuid].get()
            self.clients[peer_uuid].send_multipart(['',data])
            self.sleep.set()
            with taco.globals.upload_limiter_lock: taco.globals.upload_limiter.add(len(data))

        #medium priority queue processing
        with taco.globals.medium_priority_output_queue_lock:
          while not taco.globals.medium_priority_output_queue[peer_uuid].empty():
            self.set_status("medium priority output q not empty:" + peer_uuid)
            data = taco.globals.medium_priority_output_queue[peer_uuid].get()
            self.clients[peer_uuid].send_multipart(['',data])
            self.sleep.set()
            with taco.globals.upload_limiter_lock: taco.globals.upload_limiter.add(len(data))

        #filereq q, aka the download throttle 
        if time.time() >= self.file_request_time:
          self.file_request_time = time.time() 
          with taco.globals.file_request_output_queue_lock:
            if not taco.globals.file_request_output_queue[peer_uuid].empty():
              with taco.globals.download_limiter_lock: download_rate = taco.globals.download_limiter.get_rate()

              bw_percent = download_rate / self.max_download_rate
              wait_time = self.chunk_request_rate * bw_percent
              #self.set_status(str((download_rate,self.max_download_rate,self.chunk_request_rate,bw_percent,wait_time)))
              if wait_time > 0.01: self.file_request_time += wait_time

              if download_rate < self.max_download_rate:
                self.set_status("filereq output q not empty+free bw:" + peer_uuid)
                data = taco.globals.file_request_output_queue[peer_uuid].get()
                self.clients[peer_uuid].send_multipart(['',data])
                self.sleep.set()
                with taco.globals.upload_limiter_lock: taco.globals.upload_limiter.add(len(data))

        #low priority queue processing
        with taco.globals.low_priority_output_queue_lock:
          if not taco.globals.low_priority_output_queue[peer_uuid].empty():
            with taco.globals.upload_limiter_lock: upload_rate = taco.globals.upload_limiter.get_rate()
            if upload_rate < self.max_upload_rate:
              self.set_status("low priority output q not empty+free bw:" + peer_uuid)
              data = taco.globals.low_priority_output_queue[peer_uuid].get()
              self.clients[peer_uuid].send_multipart(['',data])
              self.sleep.set()
              with taco.globals.upload_limiter_lock: taco.globals.upload_limiter.add(len(data))

        #rollcall special case
        if self.next_rollcall[peer_uuid] < time.time():
          #self.set_status("Requesting Rollcall from: " + peer_uuid)
          data = taco.commands.Request_Rollcall()
          self.clients[peer_uuid].send_multipart(['',data])
          with taco.globals.upload_limiter_lock: taco.globals.upload_limiter.add(len(data))
          self.next_rollcall[peer_uuid] = time.time() + random.randint(taco.constants.ROLLCALL_MIN,taco.constants.ROLLCALL_MAX)
          self.sleep.set()
          #continue

        #RECEIVE BLOCK
        socks = dict(poller.poll(0))
        while self.clients[peer_uuid] in socks and socks[self.clients[peer_uuid]] == zmq.POLLIN:
          #self.set_status("Socket Read Possible")
          sink,data = self.clients[peer_uuid].recv_multipart()
          with taco.globals.download_limiter_lock: taco.globals.download_limiter.add(len(data))
          self.set_client_last_reply(peer_uuid)
          self.next_request = taco.commands.Process_Reply(peer_uuid,data)
          if self.next_request != "":
            with taco.globals.medium_priority_output_queue_lock:
              taco.globals.medium_priority_output_queue[peer_uuid].put(self.next_request)
          self.sleep.set()
          socks = dict(poller.poll(0))

        #cleanup block
        self.error_msg = []
        if self.clients[peer_uuid] in socks and socks[self.clients[peer_uuid]] == zmq.POLLERR: self.error_msg.append("got a socket error")
        if abs(self.client_timeout[peer_uuid] - time.time()) > taco.constants.ROLLCALL_TIMEOUT: self.error_msg.append("haven't seen communications")

        if len(self.error_msg) > 0:
          self.set_status("Stopping client: " + peer_uuid + " -- " + " and ".join(self.error_msg),2)
          poller.unregister(self.clients[peer_uuid])
          self.clients[peer_uuid].close(0)
          del self.clients[peer_uuid]          
          del self.client_timeout[peer_uuid]
          with taco.globals.high_priority_output_queue_lock:    del taco.globals.high_priority_output_queue[peer_uuid]
          with taco.globals.medium_priority_output_queue_lock:  del taco.globals.medium_priority_output_queue[peer_uuid]
          with taco.globals.low_priority_output_queue_lock:     del taco.globals.low_priority_output_queue[peer_uuid]
          with taco.globals.file_request_output_queue_lock:     del taco.globals.file_request_output_queue[peer_uuid]
          self.client_reconnect_mod[peer_uuid] = min(self.client_reconnect_mod[peer_uuid] + taco.constants.CLIENT_RECONNECT_MOD,taco.constants.CLIENT_RECONNECT_MAX)
          self.client_connect_time[peer_uuid] = time.time() + self.client_reconnect_mod[peer_uuid]
          

        
    self.set_status("Terminating Clients")
    for peer_uuid in self.clients.keys():
      self.clients[peer_uuid].close(0)
    self.set_status("Stopping zmq ThreadedAuthenticator")
    clientauth.stop() 
    clientctx.term()
    self.set_status("Clients Exit")    
예제 #26
0
class ZmqListener:
    def __init__(self, settings):

        self.redis = RedisScraper(settings)
        self.id = settings.getKey("box_id")
        self.log = logging.getLogger('ZMQ')
        self.clientPath = settings.getKey("zmq.private_cert")
        self.serverPath = settings.getKey("zmq.server_cert")
        if not self.clientPath or not self.serverPath:
            self.log.fatal(
                "zmq certificates not configured in the settings file")
            os._exit(1)

        self.host = settings.getKey("zmq.acq_host")

        self.ctx = zmq.Context()

        self.auth = ThreadAuthenticator(self.ctx)
        self.auth.start()

        #self.auth.allow('127.0.0.1')
        self.auth.configure_curve(domain='*',
                                  location=zmq.auth.CURVE_ALLOW_ANY)

        self.client = self.ctx.socket(zmq.REP)

        try:
            client_public, client_secret = zmq.auth.load_certificate(
                self.clientPath)
            self.client.curve_secretkey = client_secret
            self.client.curve_publickey = client_public

            server_public, _ = zmq.auth.load_certificate(self.serverPath)

            self.client.curve_serverkey = server_public
            self.client.connect(self.host)
        except (IOError, ValueError):
            self.log.fatal("Could not load client certificate")
            os._exit(1)

        self.log.info("ZMQ connected to " + self.host + " using certs " +
                      self.clientPath)
        self.running = False
        self.handlers = {
            opq_pb2.RequestDataMessage.PING: self.ping,
            opq_pb2.RequestDataMessage.READ: self.read
        }

    def ping(self, message):
        message.type = opq_pb2.RequestDataMessage.PONG
        message_buff = message.SerializeToString()
        self.log.info("Received a PING from server")
        self.client.send(message_buff)
        return True

    def read(self, message):
        self.log.debug("Received a data transfer request from server")
        try:
            if message.front == 0 or message.back == 0:
                message.type = opq_pb2.RequestDataMessage.ERROR
                message_buff = message.SerializeToString()
                self.log.info("Bad message from server")
                self.client.send(message_buff)
                return False
            cycles = self.redis.getRange(message.time - message.back,
                                         message.time + message.front)
            cycles.id = self.id
            cycles.mid = message.mid
            message_buff = cycles.SerializeToString()
            self.client.send(message_buff)
        except google.protobuf.message.DecodeError:
            self.log.fatal("Bad request from acquisition server.")
            return False

    def run(self):
        self.running = True
        try:
            while self.running:
                message_buff = self.client.recv()
                message = opq_pb2.RequestDataMessage()
                message.ParseFromString(message_buff)
                self.handlers[message.type](message)
        except google.protobuf.message.DecodeError:
            self.log.fatal("Bad request from acquisition server.")
예제 #27
0
class Command(LAVADaemonCommand):
    help = "LAVA log recorder"
    logger = None
    default_logfile = "/var/log/lava-server/lava-logs.log"

    def __init__(self, *args, **options):
        super().__init__(*args, **options)
        self.logger = logging.getLogger("lava-logs")
        self.log_socket = None
        self.auth = None
        self.controler = None
        self.inotify_fd = None
        self.pipe_r = None
        self.poller = None
        self.cert_dir_path = None
        # List of logs
        self.jobs = {}
        # Keep test cases in memory
        self.test_cases = []
        # Master status
        self.last_ping = 0
        self.ping_interval = TIMEOUT

    def add_arguments(self, parser):
        super().add_arguments(parser)

        net = parser.add_argument_group("network")
        net.add_argument('--socket',
                         default='tcp://*:5555',
                         help="Socket waiting for logs. Default: tcp://*:5555")
        net.add_argument('--master-socket',
                         default='tcp://localhost:5556',
                         help="Socket for master-slave communication. Default: tcp://localhost:5556")
        net.add_argument('--ipv6', default=False, action='store_true',
                         help="Enable IPv6 on the listening sockets")
        net.add_argument('--encrypt', default=False, action='store_true',
                         help="Encrypt messages")
        net.add_argument('--master-cert',
                         default='/etc/lava-dispatcher/certificates.d/master.key_secret',
                         help="Certificate for the master socket")
        net.add_argument('--slaves-certs',
                         default='/etc/lava-dispatcher/certificates.d',
                         help="Directory for slaves certificates")

    def handle(self, *args, **options):
        # Initialize logging.
        self.setup_logging("lava-logs", options["level"],
                           options["log_file"], FORMAT)

        self.logger.info("[INIT] Dropping privileges")
        if not self.drop_privileges(options['user'], options['group']):
            self.logger.error("[INIT] Unable to drop privileges")
            return

        filename = os.path.join(settings.MEDIA_ROOT, 'lava-logs-config.yaml')
        self.logger.debug("[INIT] Dumping config to %s", filename)
        with open(filename, 'w') as output:
            yaml.dump(options, output)

        # Create the sockets
        context = zmq.Context()
        self.log_socket = context.socket(zmq.PULL)
        self.controler = context.socket(zmq.ROUTER)
        self.controler.setsockopt(zmq.IDENTITY, b"lava-logs")
        # Limit the number of messages in the queue
        self.controler.setsockopt(zmq.SNDHWM, 2)
        # From http://api.zeromq.org/4-2:zmq-setsockopt#toc5
        # "Immediately readies that connection for data transfer with the master"
        self.controler.setsockopt(zmq.CONNECT_RID, b"master")

        if options['ipv6']:
            self.logger.info("[INIT] Enabling IPv6")
            self.log_socket.setsockopt(zmq.IPV6, 1)
            self.controler.setsockopt(zmq.IPV6, 1)

        if options['encrypt']:
            self.logger.info("[INIT] Starting encryption")
            try:
                self.auth = ThreadAuthenticator(context)
                self.auth.start()
                self.logger.debug("[INIT] Opening master certificate: %s", options['master_cert'])
                master_public, master_secret = zmq.auth.load_certificate(options['master_cert'])
                self.logger.debug("[INIT] Using slaves certificates from: %s", options['slaves_certs'])
                self.auth.configure_curve(domain='*', location=options['slaves_certs'])
            except OSError as err:
                self.logger.error("[INIT] %s", err)
                self.auth.stop()
                return
            self.log_socket.curve_publickey = master_public
            self.log_socket.curve_secretkey = master_secret
            self.log_socket.curve_server = True
            self.controler.curve_publickey = master_public
            self.controler.curve_secretkey = master_secret
            self.controler.curve_serverkey = master_public

        self.logger.debug("[INIT] Watching %s", options["slaves_certs"])
        self.cert_dir_path = options["slaves_certs"]
        self.inotify_fd = watch_directory(options["slaves_certs"])
        if self.inotify_fd is None:
            self.logger.error("[INIT] Unable to start inotify")

        self.log_socket.bind(options['socket'])
        self.controler.connect(options['master_socket'])

        # Poll on the sockets. This allows us to have a
        # nice timeout along with polling.
        self.poller = zmq.Poller()
        self.poller.register(self.log_socket, zmq.POLLIN)
        self.poller.register(self.controler, zmq.POLLIN)
        if self.inotify_fd is not None:
            self.poller.register(os.fdopen(self.inotify_fd), zmq.POLLIN)

        # Translate signals into zmq messages
        (self.pipe_r, _) = self.setup_zmq_signal_handler()
        self.poller.register(self.pipe_r, zmq.POLLIN)

        self.logger.info("[INIT] listening for logs")
        # PING right now: the master is waiting for this message to start
        # scheduling.
        self.controler.send_multipart([b"master", b"PING"])

        try:
            self.main_loop()
        except BaseException as exc:
            self.logger.error("[EXIT] Unknown exception raised, leaving!")
            self.logger.exception(exc)

        # Close the controler socket
        self.controler.close(linger=0)
        self.poller.unregister(self.controler)

        # Carefully close the logging socket as we don't want to lose messages
        self.logger.info("[EXIT] Disconnect logging socket and process messages")
        endpoint = u(self.log_socket.getsockopt(zmq.LAST_ENDPOINT))
        self.logger.debug("[EXIT] unbinding from '%s'", endpoint)
        self.log_socket.unbind(endpoint)

        # Empty the queue
        try:
            while self.wait_for_messages(True):
                # Flush test cases cache for every iteration because we might
                # get killed soon.
                self.flush_test_cases()
        except BaseException as exc:
            self.logger.error("[EXIT] Unknown exception raised, leaving!")
            self.logger.exception(exc)
        finally:
            # Last flush
            self.flush_test_cases()
            self.logger.info("[EXIT] Closing the logging socket: the queue is empty")
            self.log_socket.close()
            if options['encrypt']:
                self.auth.stop()
            context.term()

    def flush_test_cases(self):
        if not self.test_cases:
            return

        # Try to save into the database
        try:
            TestCase.objects.bulk_create(self.test_cases)
            self.logger.info("Saving %d test cases", len(self.test_cases))
            self.test_cases = []
        except DatabaseError as exc:
            self.logger.error("Unable to flush the test cases")
            self.logger.exception(exc)
            self.logger.warning("Saving test cases one by one and dropping the faulty ones")
            saved = 0
            for tc in self.test_cases:
                with contextlib.suppress(DatabaseError):
                    tc.save()
                    saved += 1
            self.logger.info("%d test cases saved, %d dropped", saved, len(self.test_cases) - saved)
            self.test_cases = []

    def main_loop(self):
        last_gc = time.time()
        last_bulk_create = time.time()

        # Wait for messages
        # TODO: fix timeout computation
        while self.wait_for_messages(False):
            now = time.time()

            # Dump TestCase into the database
            if now - last_bulk_create > BULK_CREATE_TIMEOUT:
                last_bulk_create = now
                self.flush_test_cases()

            # Close old file handlers
            if now - last_gc > FD_TIMEOUT:
                last_gc = now
                # Removing keys while iterating is not allowed, so iterate over a copy of the keys
                for job_id in list(self.jobs.keys()):  # pylint: disable=consider-iterating-dictionary
                    if now - self.jobs[job_id].last_usage > FD_TIMEOUT:
                        self.logger.info("[%s] closing log file", job_id)
                        self.jobs[job_id].close()
                        del self.jobs[job_id]

            # Ping the master
            if now - self.last_ping > self.ping_interval:
                self.logger.debug("PING => master")
                self.last_ping = now
                self.controler.send_multipart([b"master", b"PING"])

    def wait_for_messages(self, leaving):
        try:
            try:
                sockets = dict(self.poller.poll(TIMEOUT * 1000))
            except zmq.error.ZMQError as exc:
                self.logger.error("[POLL] zmq error: %s", str(exc))
                return True

            # Messages
            if sockets.get(self.log_socket) == zmq.POLLIN:
                self.logging_socket()
                return True

            # Signals
            elif sockets.get(self.pipe_r) == zmq.POLLIN:
                # remove the message from the queue
                os.read(self.pipe_r, 1)

                if not leaving:
                    self.logger.info("[POLL] received a signal, leaving")
                    return False
                else:
                    self.logger.warning("[POLL] signal already handled, please wait for the process to exit")
                    return True

            # Pong received
            elif sockets.get(self.controler) == zmq.POLLIN:
                self.controler_socket()
                return True

            # Inotify socket
            if sockets.get(self.inotify_fd) == zmq.POLLIN:
                os.read(self.inotify_fd, 4096)
                self.logger.debug("[AUTH] Reloading certificates from %s",
                                  self.cert_dir_path)
                self.auth.configure_curve(domain='*',
                                          location=self.cert_dir_path)

            # Nothing received
            else:
                return not leaving

        except (OperationalError, InterfaceError):
            self.logger.info("[RESET] database connection reset")
            connection.close()
        return True

    def logging_socket(self):
        msg = self.log_socket.recv_multipart()
        try:
            (job_id, message) = (u(m) for m in msg)  # pylint: disable=unbalanced-tuple-unpacking
        except ValueError:
            # do not let a bad message stop the master.
            self.logger.error("[POLL] failed to parse log message, skipping: %s", msg)
            return

        try:
            scanned = yaml.load(message, Loader=yaml.CLoader)
        except yaml.YAMLError:
            self.logger.error("[%s] data are not valid YAML, dropping", job_id)
            return

        # Look for "results" level
        try:
            message_lvl = scanned["lvl"]
            message_msg = scanned["msg"]
        except TypeError:
            self.logger.error("[%s] not a dictionary, dropping", job_id)
            return
        except KeyError:
            self.logger.error(
                "[%s] invalid log line, missing \"lvl\" or \"msg\" keys: %s",
                job_id, message)
            return

        # Find the handler (if available)
        if job_id not in self.jobs:
            # Query the database for the job
            try:
                job = TestJob.objects.get(id=job_id)
            except TestJob.DoesNotExist:
                self.logger.error("[%s] unknown job id", job_id)
                return

            self.logger.info("[%s] receiving logs from a new job", job_id)
            # Create the sub directories (if needed)
            mkdir(job.output_dir)
            self.jobs[job_id] = JobHandler(job)

        # For 'event', send an event and log as 'debug'
        if message_lvl == 'event':
            self.logger.debug("[%s] event: %s", job_id, message_msg)
            send_event(".event", "lavaserver", {"message": message_msg, "job": job_id})
            message_lvl = "debug"
        # For 'marker', save in the database and log as 'debug'
        elif message_lvl == 'marker':
            # TODO: save on the file system in case of lava-logs restart
            m_type = message_msg.get("type")
            case = message_msg.get("case")
            if m_type is None or case is None:
                self.logger.error("[%s] invalid marker: %s", job_id, message_msg)
                return
            self.jobs[job_id].markers.setdefault(case, {})[m_type] = self.jobs[job_id].line_count()
            # This is in fact the previous line
            self.jobs[job_id].markers[case][m_type] -= 1
            self.logger.debug("[%s] marker: %s line: %s", job_id, message_msg, self.jobs[job_id].markers[case][m_type])
            return

        # Mark the file handler as used
        self.jobs[job_id].last_usage = time.time()
        # The format is a list of dictionaries
        self.jobs[job_id].write("- %s" % message)

        if message_lvl == "results":
            try:
                job = TestJob.objects.get(pk=job_id)
            except TestJob.DoesNotExist:
                self.logger.error("[%s] unknown job id", job_id)
                return
            meta_filename = create_metadata_store(message_msg, job)
            new_test_case = map_scanned_results(results=message_msg, job=job,
                                                markers=self.jobs[job_id].markers,
                                                meta_filename=meta_filename)

            if new_test_case is None:
                self.logger.warning(
                    "[%s] unable to map scanned results: %s",
                    job_id, message)
            else:
                self.test_cases.append(new_test_case)

            # Look for lava.job result
            if message_msg.get("definition") == "lava" and message_msg.get("case") == "job":
                # Flush cached test cases
                self.flush_test_cases()

                if message_msg.get("result") == "pass":
                    health = TestJob.HEALTH_COMPLETE
                    health_msg = "Complete"
                else:
                    health = TestJob.HEALTH_INCOMPLETE
                    health_msg = "Incomplete"
                self.logger.info("[%s] job status: %s", job_id, health_msg)

                infrastructure_error = (message_msg.get("error_type") in ["Bug",
                                                                          "Configuration",
                                                                          "Infrastructure"])
                if infrastructure_error:
                    self.logger.info("[%s] Infrastructure error", job_id)

                # Update status.
                with transaction.atomic():
                    # TODO: find a way to lock actual_device
                    job = TestJob.objects.select_for_update() \
                                         .get(id=job_id)
                    job.go_state_finished(health, infrastructure_error)
                    job.save()

        # n.b. logging here would produce a log entry for every message in every job.

    def controler_socket(self):
        msg = self.controler.recv_multipart()
        try:
            master_id = u(msg[0])
            action = u(msg[1])
            ping_interval = int(msg[2])

            if master_id != "master":
                self.logger.error("Invalid master id '%s'. Should be 'master'",
                                  master_id)
                return
            if action != "PONG":
                self.logger.error("Invalid answer '%s'. Should be 'PONG'",
                                  action)
                return
        except (IndexError, ValueError):
            self.logger.error("Invalid message '%s'", msg)
            return

        if ping_interval < TIMEOUT:
            self.logger.error("invalid ping interval (%d) too small", ping_interval)
            return

        self.logger.debug("master => PONG(%d)", ping_interval)
        self.ping_interval = ping_interval
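
lava-logs above re-runs configure_curve() whenever inotify reports a change in the slaves-certs directory, so new slave certificates are picked up without a restart. A minimal sketch of that hot-reload idea, using a simple mtime poll instead of inotify; the directory is a placeholder.

# Minimal sketch (placeholder directory; mtime polling stands in for inotify):
# reload the allowed client certificates while the server keeps running.
import os
import time

import zmq
from zmq.auth.thread import ThreadAuthenticator

ctx = zmq.Context.instance()
auth = ThreadAuthenticator(ctx)
auth.start()

cert_dir = "certs/allowed_clients"
auth.configure_curve(domain="*", location=cert_dir)

last_mtime = os.stat(cert_dir).st_mtime
while True:
    time.sleep(5)
    mtime = os.stat(cert_dir).st_mtime
    if mtime != last_mtime:
        last_mtime = mtime
        # Calling configure_curve() again replaces the authenticator's set of
        # allowed public keys; existing sockets do not need to be recreated.
        auth.configure_curve(domain="*", location=cert_dir)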
예제 #28
0
class Driver(drivers.BaseDriver):
    def __init__(
        self,
        args,
        encrypted_traffic_data=None,
        interface=None,
    ):
        """Initialize the Driver.

        :param args: Arguments parsed by argparse.
        :type args: Object
        :param encrypted_traffic_data: Enable|Disable encrypted traffic.
        :type encrypted_traffic_data: Dictionary
        :param interface: The interface instance (client/server)
        :type interface: Object
        """

        self.thread_processor = multiprocessing.Process
        self.event = multiprocessing.Event()
        self.semaphore = multiprocessing.Semaphore
        self.flushqueue = _FlushQueue
        self.args = args
        if getattr(self.args, "zmq_generate_keys", False) is True:
            self._generate_certificates()
            print("New certificates generated")
            raise SystemExit(0)

        self.encrypted_traffic_data = encrypted_traffic_data

        mode = getattr(self.args, "mode", None)
        if mode == "client":
            self.bind_address = self.args.zmq_server_address
        elif mode == "server":
            self.bind_address = self.args.zmq_bind_address
        else:
            self.bind_address = "*"
        self.proto = "tcp"
        self.connection_string = "{proto}://{addr}".format(
            proto=self.proto, addr=self.bind_address)

        if self.encrypted_traffic_data:
            self.encrypted_traffic = self.encrypted_traffic_data.get("enabled")
            self.secret_keys_dir = self.encrypted_traffic_data.get(
                "secret_keys_dir")
            self.public_keys_dir = self.encrypted_traffic_data.get(
                "public_keys_dir")
        else:
            self.encrypted_traffic = False
            self.secret_keys_dir = None
            self.public_keys_dir = None

        self._context = zmq.Context()
        self.ctx = self._context.instance()
        self.poller = zmq.Poller()
        self.interface = interface
        super(Driver, self).__init__(
            args=args,
            encrypted_traffic_data=self.encrypted_traffic_data,
            interface=interface,
        )
        self.bind_job = None
        self.bind_backend = None
        self.hwm = getattr(self.args, "zmq_highwater_mark", 1024)

    def __copy__(self):
        """Return a new copy of the driver."""

        return Driver(
            args=self.args,
            encrypted_traffic_data=self.encrypted_traffic_data,
            interface=self.interface,
        )

    def _backend_bind(self):
        """Bind an address to a backend socket and return the socket.

        :returns: Object
        """

        bind = self._socket_bind(
            socket_type=zmq.ROUTER,
            connection=self.connection_string,
            port=self.args.backend_port,
        )
        bind.set_hwm(self.hwm)
        self.log.debug(
            "Identity [ %s ] backend connect hwm state [ %s ]",
            self.identity,
            bind.get_hwm(),
        )
        return bind

    def _backend_connect(self):
        """Connect to a backend socket and return the socket.

        :returns: Object
        """

        self.log.debug("Establishing backend connection.")
        bind = self._socket_connect(
            socket_type=zmq.DEALER,
            connection=self.connection_string,
            port=self.args.backend_port,
        )
        bind.set_hwm(self.hwm)
        self.log.debug(
            "Identity [ %s ] backend connect hwm state [ %s ]",
            self.identity,
            bind.get_hwm(),
        )
        return bind

    def _bind_check(self, bind, interval=1, constant=1000):
        """Return True if a bind type contains work ready.

        :param bind: A given Socket bind to identify.
        :type bind: Object
        :param interval: Exponential Interval used to determine the polling
                         duration for a given socket.
        :type interval: Integer
        :param constant: Constant time used to poll for new jobs.
        :type constant: Integer
        :returns: Object
        """

        socks = dict(self.poller.poll(interval * constant))
        if socks.get(bind) == zmq.POLLIN:
            return True
        else:
            return False

    def _close(self, socket):
        if socket is None:
            return

        try:
            socket.close(linger=2)
            close_time = time.time()
            while not socket.closed:
                if time.time() - close_time > 60:
                    raise TimeoutError(
                        "Job [ {} ] failed to close transfer socket".format(
                            self.job_id))
                else:
                    socket.close(linger=2)
                    time.sleep(1)
        except Exception as e:
            self.log.error(
                "Ran into an exception while closing the socket %s",
                str(e),
            )
        else:
            self.log.debug("Backend socket closed")

    def _generate_certificates(self, base_dir="/etc/directord"):
        """Generate client and server CURVE certificate files.

        :param base_dir: Directord configuration path.
        :type base_dir: String
        """

        keys_dir = os.path.join(base_dir, "certificates")
        public_keys_dir = os.path.join(base_dir, "public_keys")
        secret_keys_dir = os.path.join(base_dir, "private_keys")

        for item in [keys_dir, public_keys_dir, secret_keys_dir]:
            os.makedirs(item, exist_ok=True)

        # Run certificate backup
        self._move_certificates(directory=public_keys_dir, backup=True)
        self._move_certificates(directory=secret_keys_dir,
                                backup=True,
                                suffix=".key_secret")

        # create new keys in certificates dir
        for item in ["server", "client"]:
            self._key_generate(keys_dir=keys_dir, key_type=item)

        # Move generated certificates in place
        self._move_certificates(
            directory=keys_dir,
            target_directory=public_keys_dir,
            suffix=".key",
        )
        self._move_certificates(
            directory=keys_dir,
            target_directory=secret_keys_dir,
            suffix=".key_secret",
        )

    def _job_bind(self):
        """Bind an address to a job socket and return the socket.

        :returns: Object
        """

        return self._socket_bind(
            socket_type=zmq.ROUTER,
            connection=self.connection_string,
            port=self.args.job_port,
        )

    def _job_connect(self):
        """Connect to a job socket and return the socket.

        :returns: Object
        """

        self.log.debug("Establishing Job connection.")
        return self._socket_connect(
            socket_type=zmq.DEALER,
            connection=self.connection_string,
            port=self.args.job_port,
        )

    def _key_generate(self, keys_dir, key_type):
        """Generate certificate.

        :param keys_dir: Full Directory path where a given key will be stored.
        :type keys_dir: String
        :param key_type: Key type to be generated.
        :type key_type: String
        """

        zmq_auth.create_certificates(keys_dir, key_type)

    @staticmethod
    def _move_certificates(directory,
                           target_directory=None,
                           backup=False,
                           suffix=".key"):
        """Move certificates when required.

        :param directory: Set the origin path.
        :type directory: String
        :param target_directory: Set the target path.
        :type target_directory: String
        :param backup: Enable file backup before moving.
        :type backup:  Boolean
        :param suffix: Set the search suffix
        :type suffix: String
        """

        for item in os.listdir(directory):
            if backup:
                target_file = "{}.bak".format(os.path.basename(item))
            else:
                target_file = os.path.basename(item)

            if item.endswith(suffix):
                os.rename(
                    os.path.join(directory, item),
                    os.path.join(target_directory or directory, target_file),
                )

    def _socket_bind(self, socket_type, connection, port, poller_type=None):
        """Return a socket object which has been bound to a given address.

        When the socket_type is not PUB or PUSH, the bound socket will also be
        registered with self.poller as defined within the Interface
        class.

        :param socket_type: Set the Socket type, typically defined using a
                            ZeroMQ constant.
        :type socket_type: Integer
        :param connection: Set the Address information used for the bound
                           socket.
        :type connection: String
        :param port: Define the port which the socket will be bound to.
        :type port: Integer
        :param poller_type: Set the Socket type, typically defined using a
                            ZeroMQ constant.
        :type poller_type: Integer
        :returns: Object
        """

        if poller_type is None:
            poller_type = zmq.POLLIN

        bind = self._socket_context(socket_type=socket_type)
        auth_enabled = (self.args.zmq_shared_key
                        or self.args.zmq_curve_encryption)

        if auth_enabled:
            self.auth = ThreadAuthenticator(self.ctx, log=self.log)
            self.auth.start()
            self.auth.allow()

            if self.args.zmq_shared_key:
                # Enables basic auth
                self.auth.configure_plain(
                    domain="*", passwords={"admin": self.args.zmq_shared_key})
                bind.plain_server = True  # Enable shared key authentication
                self.log.info("Shared key authentication enabled.")
            elif self.args.zmq_curve_encryption:
                server_secret_file = os.path.join(self.secret_keys_dir,
                                                  "server.key_secret")
                for item in [
                        self.public_keys_dir,
                        self.secret_keys_dir,
                        server_secret_file,
                ]:
                    if not os.path.exists(item):
                        raise SystemExit(
                            "The required path [ {} ] does not exist. Have"
                            " you generated your keys?".format(item))
                self.auth.configure_curve(domain="*",
                                          location=self.public_keys_dir)
                try:
                    server_public, server_secret = zmq_auth.load_certificate(
                        server_secret_file)
                except OSError as e:
                    self.log.error(
                        "Failed to load certificates: %s, Configuration: %s",
                        str(e),
                        vars(self.args),
                    )
                    raise SystemExit("Failed to load certificates")
                else:
                    bind.curve_secretkey = server_secret
                    bind.curve_publickey = server_public
                    bind.curve_server = True  # Enable curve authentication
        bind.bind("{connection}:{port}".format(
            connection=connection,
            port=port,
        ))

        if socket_type not in [zmq.PUB]:
            self.poller.register(bind, poller_type)

        return bind

    def _socket_connect(self, socket_type, connection, port, poller_type=None):
        """Return a socket object which has been bound to a given address.

        > A connection back to the server will wait 10 seconds for an ack
          before going into a retry loop. This is done to forcefully cycle
          the connection object to reset.

        :param socket_type: Set the Socket type, typically defined using a
                            ZeroMQ constant.
        :type socket_type: Integer
        :param connection: Set the Address information used for the bound
                           socket.
        :type connection: String
        :param port: Define the port which the socket will be bound to.
        :type port: Integer
        :param poller_type: Set the Socket type, typically defined using a
                            ZeroMQ constant.
        :type poller_type: Integer
        :returns: Object
        """

        if poller_type is None:
            poller_type = zmq.POLLIN

        bind = self._socket_context(socket_type=socket_type)

        if self.args.zmq_shared_key:
            bind.plain_username = b"admin"  # User is hard coded.
            bind.plain_password = self.args.zmq_shared_key.encode()
            self.log.info("Shared key authentication enabled.")
        elif self.args.zmq_curve_encryption:
            client_secret_file = os.path.join(self.secret_keys_dir,
                                              "client.key_secret")
            server_public_file = os.path.join(self.public_keys_dir,
                                              "server.key")
            for item in [
                    self.public_keys_dir,
                    self.secret_keys_dir,
                    client_secret_file,
                    server_public_file,
            ]:
                if not os.path.exists(item):
                    raise SystemExit(
                        "The required path [ {} ] does not exist. Have"
                        " you generated your keys?".format(item))
            try:
                client_public, client_secret = zmq_auth.load_certificate(
                    client_secret_file)
                server_public, _ = zmq_auth.load_certificate(
                    server_public_file)
            except OSError as e:
                self.log.error(
                    "Error while loading certificates: %s. Configuration: %s",
                    str(e),
                    vars(self.args),
                )
                raise SystemExit("Failed to load keys.")
            else:
                bind.curve_secretkey = client_secret
                bind.curve_publickey = client_public
                bind.curve_serverkey = server_public

        if socket_type == zmq.SUB:
            bind.setsockopt_string(zmq.SUBSCRIBE, self.identity)
        else:
            bind.setsockopt_string(zmq.IDENTITY, self.identity)

        self.poller.register(bind, poller_type)
        bind.connect("{connection}:{port}".format(
            connection=connection,
            port=port,
        ))

        self.log.info("Socket connected to [ %s ].", connection)
        return bind

    def _socket_context(self, socket_type):
        """Create socket context and return a bind object.

        :param socket_type: Set the Socket type, typically defined using a
                            ZeroMQ constant.
        :type socket_type: Integer
        :returns: Object
        """

        bind = self.ctx.socket(socket_type)
        bind.linger = getattr(self.args, "heartbeat_interval", 60)
        hwm = int(self.hwm * 4)
        try:
            bind.sndhwm = bind.rcvhwm = hwm
        except AttributeError:
            bind.hwm = hwm

        bind.set_hwm(hwm)
        bind.setsockopt(zmq.SNDHWM, hwm)
        bind.setsockopt(zmq.RCVHWM, hwm)
        if socket_type == zmq.ROUTER:
            bind.setsockopt(zmq.ROUTER_MANDATORY, 1)

        return bind

    @staticmethod
    def _socket_recv(socket, nonblocking=False):
        """Receive a message over a ZM0 socket.

        The message specification for server is as follows.

            [
                b"Identity"
                b"ID",
                b"ASCII Control Characters",
                b"command",
                b"data",
                b"info",
                b"stderr",
                b"stdout",
            ]

        The message specification for client is as follows.

            [
                b"ID",
                b"ASCII Control Characters",
                b"command",
                b"data",
                b"info",
                b"stderr",
                b"stdout",
            ]

        All message parts are byte encoded.

        All possible control characters are defined within the Interface class.
        For more on control characters review the following
        URL(https://donsnotes.com/tech/charsets/ascii.html#cntrl).

        :param socket: ZeroMQ socket object.
        :type socket: Object
        :param nonblocking: Enable non-blocking receive.
        :type nonblocking: Boolean
        """

        if nonblocking:
            flags = zmq.NOBLOCK
        else:
            flags = 0

        return socket.recv_multipart(flags=flags)

    @tenacity.retry(
        retry=tenacity.retry_if_exception_type(Exception),
        wait=tenacity.wait_fixed(5),
        before_sleep=tenacity.before_sleep_log(
            logger.getLogger(name="directord"), logging.WARN),
    )
    def _socket_send(
        self,
        socket,
        identity=None,
        msg_id=None,
        control=None,
        command=None,
        data=None,
        info=None,
        stderr=None,
        stdout=None,
        nonblocking=False,
    ):
        """Send a message over a ZM0 socket.

        The message specification for server is as follows.

            [
                b"Identity"
                b"ID",
                b"ASCII Control Characters",
                b"command",
                b"data",
                b"info",
                b"stderr",
                b"stdout",
            ]

        The message specification for client is as follows.

            [
                b"ID",
                b"ASCII Control Characters",
                b"command",
                b"data",
                b"info",
                b"stderr",
                b"stdout",
            ]

        All message information is assumed to be byte encoded.

        All possible control characters are defined within the Interface class.
        For more on control characters review the following
        URL(https://donsnotes.com/tech/charsets/ascii.html#cntrl).

        :param socket: ZeroMQ socket object.
        :type socket: Object
        :param identity: Target where message will be sent.
        :type identity: Bytes
        :param msg_id: ID information for a given message. If no ID is
                       provided a UUID will be generated.
        :type msg_id: Bytes
        :param control: ASCII control characters.
        :type control: Bytes
        :param command: Command definition for a given message.
        :type command: Bytes
        :param data: Encoded data that will be transmitted.
        :type data: Bytes
        :param info: Encoded information that will be transmitted.
        :type info: Bytes
        :param stderr: Encoded error information from a command.
        :type stderr: Bytes
        :param stdout: Encoded output information from a command.
        :type stdout: Bytes
        :param nonblocking: Enable non-blocking send.
        :type nonblocking: Boolean
        :returns: Object
        """
        def _encoder(item):
            try:
                return item.encode()
            except AttributeError:
                return item

        if not msg_id:
            msg_id = utils.get_uuid()

        if not control:
            control = self.nullbyte

        if not command:
            command = self.nullbyte

        if not data:
            data = self.nullbyte

        if not info:
            info = self.nullbyte

        if not stderr:
            stderr = self.nullbyte

        if not stdout:
            stdout = self.nullbyte

        message_parts = [msg_id, control, command, data, info, stderr, stdout]

        if identity:
            message_parts.insert(0, identity)

        message_parts = [_encoder(i) for i in message_parts]

        if nonblocking:
            flags = zmq.NOBLOCK
        else:
            flags = 0

        try:
            return socket.send_multipart(message_parts, flags=flags)
        except Exception as e:
            self.log.warn("Failed to send message to [ %s ]", identity)
            raise e
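
    # Illustrative usage (a sketch; the values shown are assumptions drawn
    # from heartbeat_send below, not a second API):
    #
    #   self._socket_send(
    #       socket=self.bind_job,
    #       control=self.heartbeat_notice,
    #       data=json.dumps({"job_id": job_id}),
    #   )
    #
    # Any unset part falls back to self.nullbyte, and every part is byte
    # encoded before the multipart message is sent.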

    def _recv(self, socket, nonblocking=False):
        """Receive message.

        :param socket: ZeroMQ socket object.
        :type socket: Object
        :param nonblocking: Enable non-blocking receive.
        :type nonblocking: Boolean
        :returns: Tuple
        """

        recv_obj = self._socket_recv(socket=socket, nonblocking=nonblocking)
        return tuple([i.decode() for i in recv_obj])

    def backend_recv(self, nonblocking=False):
        """Receive a transfer message.

        :param nonblocking: Enable non-blocking receive.
        :type nonblocking: Boolean
        :returns: Tuple
        """

        return self._recv(socket=self.bind_backend, nonblocking=nonblocking)

    def backend_init(self):
        """Initialize the backend socket.

        For server mode, this is a bound local socket.
        For client mode, it is a connection to the server socket.

        :returns: Object
        """

        if self.args.mode == "server":
            self.bind_backend = self._backend_bind()
        else:
            self.bind_backend = self._backend_connect()

    def backend_close(self):
        """Close the backend socket."""

        self._close(socket=self.bind_backend)

    def backend_check(self, interval=1, constant=1000):
        """Return True if the backend contains work ready.

        :param interval: Exponential Interval used to determine the polling
                         duration for a given socket.
        :type interval: Integer
        :param constant: Constant time used to poll for new jobs.
        :type constant: Integer
        :returns: Object
        """

        return self._bind_check(bind=self.bind_backend,
                                interval=interval,
                                constant=constant)

    def backend_send(self, *args, **kwargs):
        """Send a job message.

        * All args and kwargs are passed through to the socket send.

        :returns: Object
        """

        kwargs["socket"] = self.bind_backend
        return self._socket_send(*args, **kwargs)

    @staticmethod
    def get_lock():
        """Returns a thread lock."""

        return multiprocessing.Lock()

    def heartbeat_send(self,
                       host_uptime=None,
                       agent_uptime=None,
                       version=None,
                       driver=None):
        """Send a heartbeat.

        :param host_uptime: Sender uptime
        :type host_uptime: String
        :param agent_uptime: Sender agent uptime
        :type agent_uptime: String
        :param version: Sender directord version
        :type version: String
        :param driver: Driver information
        :type driver: String
        """

        job_id = utils.get_uuid()
        self.log.info(
            "Job [ %s ] sending heartbeat from [ %s ] to server",
            job_id,
            self.identity,
        )

        return self.job_send(
            control=self.heartbeat_notice,
            msg_id=job_id,
            data=json.dumps({
                "job_id": job_id,
                "version": version,
                "host_uptime": host_uptime,
                "agent_uptime": agent_uptime,
                "machine_id": self.machine_id,
                "driver": driver,
            }),
        )

    def job_send(self, *args, **kwargs):
        """Send a job message.

        * All args and kwargs are passed through to the socket send.

        :returns: Object
        """

        kwargs["socket"] = self.bind_job
        return self._socket_send(*args, **kwargs)

    def job_recv(self, nonblocking=False):
        """Receive a transfer message.

        :param nonblocking: Enable non-blocking receve.
        :type nonblocking: Boolean
        :returns: Tuple
        """

        return self._recv(socket=self.bind_job, nonblocking=nonblocking)

    def job_init(self):
        """Initialize the job socket.

        For server mode, this is a bound local socket.
        For client mode, it is a connection to the server socket.

        :returns: Object
        """

        if self.args.mode == "server":
            self.bind_job = self._job_bind()
        else:
            self.bind_job = self._job_connect()

    def job_close(self):
        """Close the job socket."""

        self._close(socket=self.bind_job)

    def job_check(self, interval=1, constant=1000):
        """Return True if a job contains work ready.

        :param interval: Exponential Interval used to determine the polling
                         duration for a given socket.
        :type interval: Integer
        :param constant: Constant time used to poll for new jobs.
        :type constant: Integer
        :returns: Object
        """

        return self._bind_check(bind=self.bind_job,
                                interval=interval,
                                constant=constant)

    def shutdown(self):
        """Shutdown the driver."""

        if hasattr(self.ctx, "close"):
            self.ctx.close()

        if hasattr(self._context, "close"):
            self._context.close()

        self.job_close()
        self.backend_close()
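
    # Illustrative lifecycle (a sketch only; the method names come from this
    # class, the exact flow is an assumption for demonstration):
    #
    #   driver.job_init()                  # bind (server) or connect (client)
    #   if driver.job_check(interval=1):   # poll for pending work
    #       message = driver.job_recv()    # decoded tuple of message parts
    #   driver.shutdown()                  # close sockets and the context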
Example #29
class dataTransfer():
    def __init__ (self, connectionType, signalHost = None, useLog = False, context = None):

        if useLog:
            self.log = logging.getLogger("dataTransferAPI")
        elif useLog == None:
            self.log = noLoggingFunction()
        else:
            self.log = loggingFunction()

        # ZMQ applications always start by creating a context,
        # and then using that for creating sockets
        # (source: ZeroMQ, Messaging for Many Applications by Pieter Hintjens)
        if context:
            self.context    = context
            self.extContext = True
        else:
            self.context    = zmq.Context()
            self.extContext = False


        self.signalHost            = signalHost
        self.signalPort            = "50000"
        self.requestPort           = "50001"
        self.dataHost              = None
        self.dataPort              = None

        self.signalSocket          = None
        self.dataSocket            = None
        self.requestSocket         = None

        self.poller                = zmq.Poller()

        self.auth                  = None

        self.targets               = None

        self.supportedConnections = ["stream", "streamMetadata", "queryNext", "queryMetadata"]

        self.signalExchanged       = None

        self.streamStarted         = None
        self.queryNextStarted      = None

        self.socketResponseTimeout = 1000

        if connectionType in self.supportedConnections:
            self.connectionType = connectionType
        else:
            raise NotSupported("Chosen type of connection is not supported.")


    # targets: [host, port, prio] or [[host, port, prio], ...]
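    # Illustrative values (host names and ports below are assumptions):
    #   ["detector01.example", "50101", 1]
    #   ["detector01.example", "50101", 1, [".cbf", ".tif"]]
    #   [["detector01.example", "50101", 1], ["detector02.example", "50102", 2]]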
    def initiate (self, targets):

        if type(targets) != list:
            self.stop()
            raise FormatError("Argument 'targets' must be list.")

        if not self.context:
            self.context    = zmq.Context()
            self.extContext = False

        signal = None
        # Signal exchange
        if self.connectionType == "stream":
            signalPort = self.signalPort
            signal     = "START_STREAM"
        elif self.connectionType == "streamMetadata":
            signalPort = self.signalPort
            signal     = "START_STREAM_METADATA"
        elif self.connectionType == "queryNext":
            signalPort = self.signalPort
            signal     = "START_QUERY_NEXT"
        elif self.connectionType == "queryMetadata":
            signalPort = self.signalPort
            signal     = "START_QUERY_METADATA"

        self.log.debug("Create socket for signal exchange...")


        if self.signalHost:
            self.__createSignalSocket(signalPort)
        else:
            self.stop()
            raise ConnectionFailed("No host to send signal to specified." )

        self.__setTargets (targets)

        message = self.__sendSignal(signal)

        if message and message == "VERSION_CONFLICT":
            self.stop()
            raise VersionError("Versions are conflicting.")

        elif message and message == "NO_VALID_HOST":
            self.stop()
            raise AuthenticationFailed("Host is not allowed to connect.")

        elif message and message == "CONNECTION_ALREADY_OPEN":
            self.stop()
            raise CommunicationFailed("Connection is already open.")

        elif message and message == "NO_VALID_SIGNAL":
            self.stop()
            raise CommunicationFailed("Connection type is not supported for this kind of sender.")

        elif message and message.startswith(signal):
            self.log.info("Received confirmation ...")
            self.signalExchanged = signal

        # if there was no response or the response was of the wrong format,
        # the receiver should be shut down
        else:
            raise CommunicationFailed("Sending start signal ...failed.")


    def __createSignalSocket (self, signalPort):

        # To send a notification that a Displayer is up and running, a communication socket is needed
        # create socket to exchange signals with Sender
        self.signalSocket = self.context.socket(zmq.REQ)

        # time to wait for the sender to give a confirmation of the signal
#        self.signalSocket.RCVTIMEO = self.socketResponseTimeout
        connectionStr = "tcp://" + str(self.signalHost) + ":" + str(signalPort)
        try:
            self.signalSocket.connect(connectionStr)
            self.log.info("signalSocket started (connect) for '" + connectionStr + "'")
        except:
            self.log.error("Failed to start signalSocket (connect): '" + connectionStr + "'")
            raise

        # using a Poller to implement the signalSocket timeout (in older ZMQ version there is no option RCVTIMEO)
        self.poller.register(self.signalSocket, zmq.POLLIN)


    def __setTargets (self, targets):
        self.targets = []

        # [host, port, prio]
        if len(targets) == 3 and type(targets[0]) != list and type(targets[1]) != list and type(targets[2]) != list:
            host, port, prio = targets
            self.targets = [[host + ":" + port, prio, [""]]]

        # [host, port, prio, suffixes]
        elif len(targets) == 4 and type(targets[0]) != list and type(targets[1]) != list and type(targets[2]) != list and type(targets[3]) == list:
            host, port, prio, suffixes = targets
            self.targets = [[host + ":" + port, prio, suffixes]]

        # [[host, port, prio], ...] or [[host, port, prio, suffixes], ...]
        else:
            for t in targets:
                if type(t) == list and len(t) == 3:
                    host, port, prio = t
                    self.targets.append([host + ":" + port, prio, [""]])
                elif type(t) == list and len(t) == 4 and type(t[3]) == list:
                    host, port, prio, suffixes = t
                    self.targets.append([host + ":" + port, prio, suffixes])
                else:
                    self.stop()
                    self.log.debug("targets=" + str(targets))
                    raise FormatError("Argument 'targets' is of wrong format.")


    def __sendSignal (self, signal):

        if not signal:
            return

        # Send the signal that the communication infrastructure should be established
        self.log.info("Sending Signal")

        sendMessage = [__version__,  signal]

        trg = cPickle.dumps(self.targets)
        sendMessage.append(trg)

#        sendMessage = [__version__, signal, self.dataHost, self.dataPort]

        self.log.debug("Signal: " + str(sendMessage))
        try:
            self.signalSocket.send_multipart(sendMessage)
        except:
            self.log.error("Could not send signal")
            raise

        message = None
        try:
            socks = dict(self.poller.poll(self.socketResponseTimeout))
        except:
            self.log.error("Could not poll for new message")
            raise


        # if there was a response
        if self.signalSocket in socks and socks[self.signalSocket] == zmq.POLLIN:
            try:
                #  Get the reply.
                message = self.signalSocket.recv()
                self.log.info("Received answer to signal: " + str(message) )

            except:
                self.log.error("Could not receive answer to signal")
                raise

        return message


    def start (self, dataSocket = False, whitelist = None):

        # Receive data only from whitelisted nodes
        if whitelist:
            if type(whitelist) == list:
                self.auth = ThreadAuthenticator(self.context)
                self.auth.start()
                for host in whitelist:
                    try:
                        if host == "localhost":
                            ip = [socket.gethostbyname(host)]
                        else:
                            hostname, tmp, ip = socket.gethostbyaddr(host)

                        self.log.debug("Allowing host " + host + " (" + str(ip[0]) + ")")
                        self.auth.allow(ip[0])
                    except:
                        self.log.error("Error was: ", exc_info=True)
                        raise AuthenticationFailed("Could not get IP of host " + host)
            else:
                raise FormatError("Whitelist has to be a list of IPs")


        socketIdToConnect = self.streamStarted or self.queryNextStarted

        if socketIdToConnect:
            self.log.info("Reopening already started connection.")
        else:

            ip   = "0.0.0.0"           #TODO use IP of hostname?

            host = ""
            port = ""

            if dataSocket:
                if type(dataSocket) == list:
                    socketIdToConnect = dataSocket[0] + ":" + dataSocket[1]
                    host = dataSocket[0]
                    ip   = socket.gethostbyaddr(host)[2][0]
                    port = dataSocket[1]
                else:
                    port = str(dataSocket)

                    host = socket.gethostname()
                    socketId = host + ":" + port
                    ipFromHost = socket.gethostbyaddr(host)[2]
                    if len(ipFromHost) == 1:
                        ip = ipFromHost[0]

            elif len(self.targets) == 1:
                host, port = self.targets[0][0].split(":")
                ipFromHost = socket.gethostbyaddr(host)[2]
                if len(ipFromHost) == 1:
                    ip = ipFromHost[0]

            else:
                raise FormatError("Multipe possible ports. Please choose which one to use.")

            socketId = host + ":" + port
            socketIdToConnect = ip + ":" + port
#            socketIdToConnect = "[" + ip + "]:" + port


        self.dataSocket = self.context.socket(zmq.PULL)
        # An additional socket is needed to establish the data retrieving mechanism
        connectionStr = "tcp://" + socketIdToConnect
        if whitelist:
            self.dataSocket.zap_domain = b'global'

        try:
#            self.dataSocket.ipv6 = True
            self.dataSocket.bind(connectionStr)
#            self.dataSocket.bind("tcp://[2003:ce:5bc0:a600:fa16:54ff:fef4:9fc0]:50102")
            self.log.info("Data socket of type " + self.connectionType + " started (bind) for '" + connectionStr + "'")
        except:
            self.log.error("Failed to start Socket of type " + self.connectionType + " (bind): '" + connectionStr + "'", exc_info=True)
            raise

        self.poller.register(self.dataSocket, zmq.POLLIN)

        if self.connectionType in ["queryNext", "queryMetadata"]:

            self.requestSocket = self.context.socket(zmq.PUSH)
            # An additional socket is needed to establish the data retrieving mechanism
            connectionStr = "tcp://" + self.signalHost + ":" + self.requestPort
            try:
                self.requestSocket.connect(connectionStr)
                self.log.info("Request socket started (connect) for '" + connectionStr + "'")
            except:
                self.log.error("Failed to start Socket of type " + self.connectionType + " (connect): '" + connectionStr + "'", exc_info=True)
                raise

            self.queryNextStarted = socketId
        else:
            self.streamStarted    = socketId


    ##
    #
    # Receives or queries for new files depending on the connection initialized
    #
    # returns either
    #   the newest file
    #       (if connection type "queryNext" or "stream" was choosen)
    #   the path of the newest file
    #       (if connection type "queryMetadata" or "streamMetadata" was choosen)
    #
    ##
    def get (self, timeout=None):

        if not self.streamStarted and not self.queryNextStarted:
            self.log.info("Could not communicate, no connection was initialized.")
            return None, None

        if self.queryNextStarted :

            sendMessage = ["NEXT", self.queryNextStarted]
            try:
                self.requestSocket.send_multipart(sendMessage)
            except Exception as e:
                self.log.error("Could not send request to requestSocket", exc_info=True)
                return None, None

        while True:
            # receive data
            if timeout:
                try:
                    socks = dict(self.poller.poll(timeout))
                except:
                    self.log.error("Could not poll for new message")
                    raise
            else:
                try:
                    socks = dict(self.poller.poll())
                except:
                    self.log.error("Could not poll for new message")
                    raise

            # if there was a response
            if self.dataSocket in socks and socks[self.dataSocket] == zmq.POLLIN:

                try:
                    multipartMessage = self.dataSocket.recv_multipart()
                except:
                    self.log.error("Receiving data..failed.", exc_info=True)
                    return [None, None]


                if multipartMessage[0] == b"ALIVE_TEST":
                    continue
                elif len(multipartMessage) < 2:
                    self.log.error("Received mutipart-message is too short. Either config or file content is missing.")
                    self.log.debug("multipartMessage=" + str(mutipartMessage)[:100])
                    return [None, None]

                # extract multipart message
                try:
                    metadata = cPickle.loads(multipartMessage[0])
                except:
                    self.log.error("Could not extract metadata from the multipart-message.", exc_info=True)
                    metadata = None

                #TODO validate multipartMessage (like correct dict-values for metadata)

                try:
                    payload = multipartMessage[1]
                except:
                    self.log.warning("An empty file was received within the multipart-message", exc_info=True)
                    payload = None

                return [metadata, payload]
            else:
                self.log.warning("Could not receive data in the given time.")

                if self.queryNextStarted :
                    try:
                        self.requestSocket.send_multipart(["CANCEL", self.queryNextStarted])
                    except Exception as e:
                        self.log.error("Could not cancel the next query", exc_info=True)

                return [None, None]
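
    # Illustrative usage of the query workflow (host name, port and priority
    # are assumptions, for demonstration only):
    #
    #   query = dataTransfer("queryNext", signalHost="zmqserver.example")
    #   query.initiate(["localhost", "50101", 1])
    #   query.start()
    #   [metadata, data] = query.get(timeout=2000)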


    def store (self, targetBasePath, dataObject):

        if type(dataObject) is not list or len(dataObject) != 2:
            raise FormatError("Wrong input type for 'store'")

        payloadMetadata   = dataObject[0]
        payload           = dataObject[1]


        if type(payloadMetadata) is not dict:
            raise FormatError("payload: Wrong input format in 'store'")

        #save all chunks to file
        while True:

            #TODO check if payload != cPickle.dumps(None) ?
            if payloadMetadata and payload:
                #append to file
                try:
                    self.log.debug("append to file based on multipart-message...")
                    #TODO: save message to file using a thread (avoids blocking)
                    #TODO: instead of opening/closing the file for each chunk, recycle the file descriptor across all chunks
                    self.__appendChunksToFile(targetBasePath, payloadMetadata, payload)
                    self.log.debug("append to file based on multipart-message...success.")
                except KeyboardInterrupt:
                    self.log.info("KeyboardInterrupt detected. Unable to append multipart-content to file.")
                    break
                except Exception as e:
                    self.log.error("Unable to append multipart-content to file.", exc_info=True)
                    self.log.debug("Append to file based on multipart-message...failed.")

                if len(payload) < payloadMetadata["chunkSize"] :
                    #indicated end of file. Leave loop
                    filename    = self.generateTargetFilepath(targetBasePath, payloadMetadata)
                    fileModTime = payloadMetadata["fileModTime"]

                    self.log.info("New file with modification time " + str(fileModTime) + " received and saved: " + str(filename))
                    break

            try:
                [payloadMetadata, payload] = self.get()
            except:
                self.log.error("Getting data failed.", exc_info=True)
                break
Example #30
class FrankFancyStreamingInterface(object):
	"""
	Abstraction layer to the graph streamer as well as the central logger
	Uses direct (non encrypted) socket connection to the streaming server
	It uses an (encrypted) zeromq connection to the logger
	"""

	ConvertStatus = {
		"Cells" : {
			0 : 5, #removing
			1 : 4, #allocating
			2 : 6  #blacklisting
		}
	}

	#TODO: give every scheduler a unique topic to easily distinguish between them on the queue
	def __init__(self, name, privatekey, VisualizerHost, root_id, ZeromqHost = "*", empty=False):
		"""
		Calls internal methods to open the connections to both the Active Live visualizer and the logger

		:param name: name of the scheduler, used as topic on the queue
		:type name: str
		:param privatekey: private key file of the scheduler, used for the logger connection
		:type privatekey: str
		:param VisualizerHost: The ip of the FrankFancyGraphStreamer
		:type VisualizerHost: str
		:param ZeromqHost: which interface the zeromq service needs to bind to ("*" for all interfaces)
		:type ZeromqHost: str
		:param root_id: the root of the network: LBR
		:type root_id: str
		:return:
		"""

		self.Active = None
		self.Logger = None
		self.EventId = 0
		self.Name = name #used as topic on the queue
		if not empty:
			if privatekey is not None:
				self._connectLogger(privatekey, Host=ZeromqHost)
			if VisualizerHost is not None:
				self._connectVisualizer(VisualizerHost, root_id)
				self.g = DoDAG(root_id, root_id)
				self.root_id = root_id

	def _connectVisualizer(self, Host, root_id):
		"""
		Connect to the Active Live Visualizer

		:param Host: The ip of the FrankFancyGraphStreamer
		:param root_id: the ip6 address of the root node of the network
		:return:
		"""
		try:
			logg.debug("Connecting Streaming Interface to Active Viewer")
			self.Active = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
			self.Active.connect((Host, 600))
			logg.debug("Sending to Active Viewer:{}".format(root_id))
			self.Active.sendall(root_id)

		except:
			logg.debug("Connection to Active Viewer failed!")
			self.Active = None

	def _connectLogger(self, key, Host="localhost"):
		"""
		Open a zeromq queue with publisher service

		:param Host: which interface the zeromq service needs to bind to ("*" for all interfaces)
		:param key: privatekey file of the scheduler
		:return:
		"""
		#TODO: error handling on certificates missing and stuff
		#TODO: expose more security options such as white/blacklisting ips and domain filtering
		self.context = zmq.Context()
		self.auth = ThreadAuthenticator(self.context)
		self.auth.start()
		self.auth.configure_curve(domain='*', location=os.path.join("keys", "public"))

		self.Logger = self.context.socket(zmq.PUB)
		scheduler_public, scheduler_secret = zmq.auth.load_certificate(os.path.join("keys", "plexi1.key_secret"))
		self.Logger.curve_secretkey = scheduler_secret
		self.Logger.curve_publickey = scheduler_public
		self.Logger.curve_server = True
		self.Logger.bind("tcp://127.0.0.1:6000")
		# raw_input("Press enter when the logger has opened subscription to us")


	def SendActiveJson(self,data):
		"""
		Sends an object as json encoded to the Active Live Viewer

		:param data: the object to be send
		:return:
		"""
		if self.Active is not None:
			logg.debug("Sending json data to Active: " + json.dumps(data))
			self.Active.sendall(json.dumps(data))

	def PublishLogging(self,LoggingName="zmq.auth", root_topic="zmq.auth"):
		"""
		Publishes the given python logger to the publishing service

		:param LoggingName: Name of the python logger service
		:type LoggingName: str
		:param root_topic: the topic given with message. is appended with .<LEVEL>
		:type root_topic: str
		:return:
		"""
		handler = PUBHandler(self.Logger)
		handler.root_topic = root_topic
		handler.formatters[logging.DEBUG] = logging.Formatter(fmt='%(asctime)s\t%(levelname)s: %(message)s', datefmt='%H:%M:%S')
		handler.formatters[logging.INFO] = logging.Formatter(fmt='%(asctime)s\t%(levelname)s: %(message)s', datefmt='%H:%M:%S')
		l = logging.getLogger(LoggingName)
		l.addHandler(handler)
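
	# Illustrative subscriber side (a sketch; the key file name and endpoint
	# below are assumptions):
	#
	#   sub = zmq.Context().socket(zmq.SUB)
	#   client_public, client_secret = zmq.auth.load_certificate("keys/logger.key_secret")
	#   sub.curve_publickey = client_public
	#   sub.curve_secretkey = client_secret
	#   sub.curve_serverkey = scheduler_public   # public key of this scheduler
	#   sub.setsockopt_string(zmq.SUBSCRIBE, "")
	#   sub.connect("tcp://127.0.0.1:6000")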

	def ChangeCell(self, who, slotoffs, channeloffs, frame, ID, status):
		"""
		Notifies all active services about the changes to a cell in the schedule matrix

		:param who: The node in which the cell is changed
		:type who: :class: `node.NodeID`
		:param slotoffs: slot offset
		:param channeloffs: channel offset
		:param frame: frame name
		:param ID: local cell id
		:param status: new status of the cell
		:return:
		"""
		if self.Active is not None:
			logg.debug("Sending ChangeCell to active viewer")
			self.Active.sendall(json.dumps(["changecell",{"who": str(who), "channeloffs":channeloffs, "slotoffs":slotoffs, "frame":frame, "id":ID, "status":status}]))
		if self.Logger is not None:
			self.EventId += 1
			logg.debug("Sending ChangeCell to logger, EventID:" + str(self.EventId))
			# self.Logger.send_multipart([self.Name.encode(), pickle.dumps({
			# 	"EventId"		: self.EventId,
			# 	"SubjectId" 	: self.ConvertStatus["Cells"][status],
			# 	"InfoString" 	: json.dumps({"who": who, "channeloffs":channeloffs, "slotoffs":slotoffs, "frame":frame, "id":ID})
			# })])
			self.Logger.send_multipart([self.Name.encode(), pickle.dumps(Event(self.EventId, self.ConvertStatus["Cells"][status], time.time(), json.dumps({"node_id": str(who), "channeloffs":channeloffs, "slotoffs":slotoffs, "frame":frame, "id":ID})))])

	def DumpDotData(self, labels={}):
		"""
		dumps an entire dot file to the active viewer. This is not used for the logger

		:return:
		"""
		# packet = "[\"" + str(self.root_id) + " at " + time.strftime("%Y-%m-%d %H:%M:%S") + "\"," + json.dumps(dotdata) + "]"
		if self.Active is not None:
			logg.debug("Sending dotdata")
			# self.Active.sendall(bytearray("[\"" + root_id + " at " + time.strftime("%Y-%m-%d %H:%M:%S") + "\"," + dotdata + "]"))
			dotdata = self.g.draw_graph(labels=labels)
			self.Active.sendall(bytearray(json.dumps(["\"" + self.root_id + " at " + time.strftime("%Y-%m-%d %H:%M:%S") + "\"", dotdata])))
			time.sleep(.5)

	def AddNode(self, node_id, parent):
		"""
		Sends a notification of joining node to the logger

		:param node_id: ip6 of the node
		:type node_id: str
		:param parent: ip6 of the parent node
		:type parent: str
		:return:
		"""
		node_id = str(node_id)
		parent = str(parent)
		if self.Logger is not None:
			self.EventId += 1
			logg.debug("Sending Addnode to logger, EventID:" + str(self.EventId))
			# self.Logger.send_multipart([self.Name.encode(), pickle.dumps({
			# 	"EventId"	: self.EventId,
			# 	"SubjectId"	: 0,
			# 	"InfoString": json.dumps({"node_id" : node_id, "parent" : parent})
			# })])
			self.Logger.send_multipart([self.Name.encode(), pickle.dumps(Event(self.EventId, 0, time.time(), json.dumps({"node_id" : str(node_id), "parent" : str(parent)})))])
		if self.Active is not None:
			logg.debug("Sending Addnode to Active Visualizer, node:{}, parent:{}".format(node_id, parent))
			if parent == "root":
				self.g.attach_node(node_id)
			else:
				self.g.attach_child(node_id, parent)
				self.DumpDotData()

	def RewireNode(self, node_id, old_parent, new_parent):
		"""
		Notifies the logger of a rewire that happened in the network

		:param node_id: ip6 of the node that has rewired
		:param old_parent: ip6 of the old parent
		:param new_parent: ip6 of the new parent
		:return:
		"""
		node_id = str(node_id)
		old_parent = str(old_parent)
		new_parent = str(new_parent)
		if self.Logger is not None:
			self.EventId += 1
			logg.debug("Sending RewireNode to logger, EventID: " + str(self.EventId))
			# self.Logger.send_multipart([self.Name.encode(), pickle.dumps({
			# 	"EventId"	: self.EventId,
			# 	"SubjectId"	: 2,
			# 	"InfoString": json.dumps({"node_id" : node_id, "old_parent" : old_parent, "new_parent" : new_parent})
			# })])
			self.Logger.send_multipart([self.Name.encode(), pickle.dumps(Event(self.EventId, 2, time.time(), json.dumps({"node_id" : str(node_id), "old_parent" : str(old_parent), "new_parent" : str(new_parent)})))])
		if self.Active is not None:
			logg.debug("Sending Rewire to the Active")
			self.g.attach_child(node_id, new_parent)
			self.DumpDotData()


	def RemoveNode(self, node_id):
		"""
		Notifies the logger of a disconnected node

		:param node_id: ip6 of the node that has disconnected
		:return:
		"""
		node_id = str(node_id)
		if self.Logger is not None:
			self.EventId += 1
			logg.debug("Sending RemoveNode to logger, EventID: " + str(self.EventId))
			# self.Logger.send_multipart([self.Name.encode(), pickle.dumps({
			# 	"EventId"	: self.EventId,
			# 	"SubjectId"	: 1,
			# 	"InfoString": json.dumps({"node_id" : node_id})
			# })])
			self.Logger.send_multipart([self.Name.encode(), pickle.dumps(Event(self.EventId, 1, time.time(), json.dumps({"node_id" : str(node_id)})))])
		if self.Active is not None:
			self.g.detach_node(node_id)
			self.DumpDotData()

	def RegisterFrame(self, num_cells, framename):
		"""
		Notifies the logger of a new frame that is defined in the scheduler algorithm

		:param num_cells: number of cells per channel
		:param framename: unique identifying name
		:return:
		"""
		if self.Logger is not None:
			self.EventId += 1
			logg.debug("Sending RegisterFrame to logger, EventID: " + str(self.EventId))
			self.Logger.send_multipart([self.Name.encode(), pickle.dumps(Event(self.EventId, 7, time.time(), json.dumps({"cells" : num_cells, "name" : framename})))])


	def RegisterFrames(self, frames):
		if self.Active is not None:
			logg.debug("Sending RegisterFrames to Active")
			self.Active.sendall(bytearray(json.dumps(frames)))
Example #31
class Command(LAVADaemonCommand):
    help = "LAVA log recorder"
    logger = None
    default_logfile = "/var/log/lava-server/lava-logs.log"

    def __init__(self, *args, **options):
        super(Command, self).__init__(*args, **options)
        self.logger = logging.getLogger("lava-logs")
        self.log_socket = None
        self.auth = None
        self.controler = None
        self.inotify_fd = None
        self.pipe_r = None
        self.poller = None
        self.cert_dir_path = None
        # List of logs
        self.jobs = {}
        # Keep test cases in memory
        self.test_cases = []
        # Master status
        self.last_ping = 0
        self.ping_interval = TIMEOUT

    def add_arguments(self, parser):
        super(Command, self).add_arguments(parser)

        net = parser.add_argument_group("network")
        net.add_argument('--socket',
                         default='tcp://*:5555',
                         help="Socket waiting for logs. Default: tcp://*:5555")
        net.add_argument('--master-socket',
                         default='tcp://localhost:5556',
                         help="Socket for master-slave communication. Default: tcp://localhost:5556")
        net.add_argument('--ipv6', default=False, action='store_true',
                         help="Enable IPv6 on the listening sockets")
        net.add_argument('--encrypt', default=False, action='store_true',
                         help="Encrypt messages")
        net.add_argument('--master-cert',
                         default='/etc/lava-dispatcher/certificates.d/master.key_secret',
                         help="Certificate for the master socket")
        net.add_argument('--slaves-certs',
                         default='/etc/lava-dispatcher/certificates.d',
                         help="Directory for slaves certificates")

    def handle(self, *args, **options):
        # Initialize logging.
        self.setup_logging("lava-logs", options["level"],
                           options["log_file"], FORMAT)

        self.logger.info("[INIT] Dropping privileges")
        if not self.drop_privileges(options['user'], options['group']):
            self.logger.error("[INIT] Unable to drop privileges")
            return

        # Create the sockets
        context = zmq.Context()
        self.log_socket = context.socket(zmq.PULL)
        self.controler = context.socket(zmq.ROUTER)
        self.controler.setsockopt(zmq.IDENTITY, b"lava-logs")
        # Limit the number of messages in the queue
        self.controler.setsockopt(zmq.SNDHWM, 2)
        # From http://api.zeromq.org/4-2:zmq-setsockopt#toc5
        # "Immediately readies that connection for data transfer with the master"
        self.controler.setsockopt(zmq.CONNECT_RID, b"master")

        if options['ipv6']:
            self.logger.info("[INIT] Enabling IPv6")
            self.log_socket.setsockopt(zmq.IPV6, 1)
            self.controler.setsockopt(zmq.IPV6, 1)

        if options['encrypt']:
            self.logger.info("[INIT] Starting encryption")
            try:
                self.auth = ThreadAuthenticator(context)
                self.auth.start()
                self.logger.debug("[INIT] Opening master certificate: %s", options['master_cert'])
                master_public, master_secret = zmq.auth.load_certificate(options['master_cert'])
                self.logger.debug("[INIT] Using slaves certificates from: %s", options['slaves_certs'])
                self.auth.configure_curve(domain='*', location=options['slaves_certs'])
            except IOError as err:
                self.logger.error("[INIT] %s", err)
                self.auth.stop()
                return
            self.log_socket.curve_publickey = master_public
            self.log_socket.curve_secretkey = master_secret
            self.log_socket.curve_server = True
            self.controler.curve_publickey = master_public
            self.controler.curve_secretkey = master_secret
            self.controler.curve_serverkey = master_public

        self.logger.debug("[INIT] Watching %s", options["slaves_certs"])
        self.cert_dir_path = options["slaves_certs"]
        self.inotify_fd = watch_directory(options["slaves_certs"])
        if self.inotify_fd is None:
            self.logger.error("[INIT] Unable to start inotify")

        self.log_socket.bind(options['socket'])
        self.controler.connect(options['master_socket'])

        # Poll on the sockets. This allows us to have a
        # nice timeout along with polling.
        self.poller = zmq.Poller()
        self.poller.register(self.log_socket, zmq.POLLIN)
        self.poller.register(self.controler, zmq.POLLIN)
        if self.inotify_fd is not None:
            self.poller.register(os.fdopen(self.inotify_fd), zmq.POLLIN)

        # Translate signals into zmq messages
        (self.pipe_r, _) = self.setup_zmq_signal_handler()
        self.poller.register(self.pipe_r, zmq.POLLIN)

        self.logger.info("[INIT] listening for logs")
        # PING right now: the master is waiting for this message to start
        # scheduling.
        self.controler.send_multipart([b"master", b"PING"])

        try:
            self.main_loop()
        except BaseException as exc:
            self.logger.error("[EXIT] Unknown exception raised, leaving!")
            self.logger.exception(exc)

        # Close the controler socket
        self.controler.close(linger=0)
        self.poller.unregister(self.controler)

        # Carefully close the logging socket as we don't want to lose messages
        self.logger.info("[EXIT] Disconnect logging socket and process messages")
        endpoint = u(self.log_socket.getsockopt(zmq.LAST_ENDPOINT))
        self.logger.debug("[EXIT] unbinding from '%s'", endpoint)
        self.log_socket.unbind(endpoint)

        # Empty the queue
        try:
            while self.wait_for_messages(True):
                # Flush test cases cache for every iteration because we might
                # get killed soon.
                self.flush_test_cases()
        except BaseException as exc:
            self.logger.error("[EXIT] Unknown exception raised, leaving!")
            self.logger.exception(exc)
        finally:
            # Last flush
            self.flush_test_cases()
            self.logger.info("[EXIT] Closing the logging socket: the queue is empty")
            self.log_socket.close()
            if options['encrypt']:
                self.auth.stop()
            context.term()

    def flush_test_cases(self):
        if self.test_cases:
            self.logger.info("Saving %d test cases", len(self.test_cases))
            TestCase.objects.bulk_create(self.test_cases)
            self.test_cases = []

    def main_loop(self):
        last_gc = time.time()
        last_bulk_create = time.time()

        # Wait for messages
        # TODO: fix timeout computation
        while self.wait_for_messages(False):
            now = time.time()

            # Dump TestCase into the database
            if now - last_bulk_create > BULK_CREATE_TIMEOUT:
                last_bulk_create = now
                self.flush_test_cases()

            # Close old file handlers
            if now - last_gc > FD_TIMEOUT:
                last_gc = now
                # Iterating while removing keys is not compatible with iterators
                for job_id in list(self.jobs.keys()):  # pylint: disable=consider-iterating-dictionary
                    if now - self.jobs[job_id].last_usage > FD_TIMEOUT:
                        self.logger.info("[%s] closing log file", job_id)
                        self.jobs[job_id].close()
                        del self.jobs[job_id]

            # Ping the master
            if now - self.last_ping > self.ping_interval:
                self.logger.debug("PING => master")
                self.last_ping = now
                self.controler.send_multipart([b"master", b"PING"])

    def wait_for_messages(self, leaving):
        try:
            try:
                sockets = dict(self.poller.poll(TIMEOUT * 1000))
            except zmq.error.ZMQError as exc:
                self.logger.error("[POLL] zmq error: %s", str(exc))
                return True

            # Messages
            if sockets.get(self.log_socket) == zmq.POLLIN:
                self.logging_socket()
                return True

            # Signals
            elif sockets.get(self.pipe_r) == zmq.POLLIN:
                # remove the message from the queue
                os.read(self.pipe_r, 1)

                if not leaving:
                    self.logger.info("[POLL] received a signal, leaving")
                    return False
                else:
                    self.logger.warning("[POLL] signal already handled, please wait for the process to exit")
                    return True

            # Pong received
            elif sockets.get(self.controler) == zmq.POLLIN:
                self.controler_socket()
                return True

            # Inotify socket
            if sockets.get(self.inotify_fd) == zmq.POLLIN:
                os.read(self.inotify_fd, 4096)
                self.logger.debug("[AUTH] Reloading certificates from %s",
                                  self.cert_dir_path)
                self.auth.configure_curve(domain='*',
                                          location=self.cert_dir_path)

            # Nothing received
            else:
                return not leaving

        except (OperationalError, InterfaceError):
            self.logger.info("[RESET] database connection reset")
            connection.close()
        return True

    def logging_socket(self):
        msg = self.log_socket.recv_multipart()
        try:
            (job_id, message) = (u(m) for m in msg)  # pylint: disable=unbalanced-tuple-unpacking
        except ValueError:
            # do not let a bad message stop the master.
            self.logger.error("[POLL] failed to parse log message, skipping: %s", msg)
            return

        try:
            scanned = yaml.load(message, Loader=yaml.CLoader)
        except yaml.YAMLError:
            self.logger.error("[%s] data are not valid YAML, dropping", job_id)
            return

        # Look for "results" level
        try:
            message_lvl = scanned["lvl"]
            message_msg = scanned["msg"]
        except TypeError:
            self.logger.error("[%s] not a dictionary, dropping", job_id)
            return
        except KeyError:
            self.logger.error(
                "[%s] invalid log line, missing \"lvl\" or \"msg\" keys: %s",
                job_id, message)
            return

        # Find the handler (if available)
        if job_id not in self.jobs:
            # Query the database for the job
            try:
                job = TestJob.objects.get(id=job_id)
            except TestJob.DoesNotExist:
                self.logger.error("[%s] unknown job id", job_id)
                return

            self.logger.info("[%s] receiving logs from a new job", job_id)
            # Create the sub directories (if needed)
            mkdir(job.output_dir)
            self.jobs[job_id] = JobHandler(job)

        if message_lvl == "results":
            try:
                job = TestJob.objects.get(pk=job_id)
            except TestJob.DoesNotExist:
                self.logger.error("[%s] unknown job id", job_id)
                return
            meta_filename = create_metadata_store(message_msg, job)
            new_test_case = map_scanned_results(results=message_msg, job=job,
                                                meta_filename=meta_filename)
            if new_test_case is None:
                self.logger.warning(
                    "[%s] unable to map scanned results: %s",
                    job_id, message)
            else:
                self.test_cases.append(new_test_case)

            # Look for lava.job result
            if message_msg.get("definition") == "lava" and message_msg.get("case") == "job":
                # Flush cached test cases
                self.flush_test_cases()

                if message_msg.get("result") == "pass":
                    health = TestJob.HEALTH_COMPLETE
                    health_msg = "Complete"
                else:
                    health = TestJob.HEALTH_INCOMPLETE
                    health_msg = "Incomplete"
                self.logger.info("[%s] job status: %s", job_id, health_msg)

                infrastructure_error = (message_msg.get("error_type") in ["Bug",
                                                                          "Configuration",
                                                                          "Infrastructure"])
                if infrastructure_error:
                    self.logger.info("[%s] Infrastructure error", job_id)

                # Update status.
                with transaction.atomic():
                    # TODO: find a way to lock actual_device
                    job = TestJob.objects.select_for_update() \
                                         .get(id=job_id)
                    job.go_state_finished(health, infrastructure_error)
                    job.save()

        # Mark the file handler as used
        self.jobs[job_id].last_usage = time.time()

        # n.b. logging here would produce a log entry for every message in every job.
        # The format is a list of dictionaries
        message = "- %s" % message

        # Write data
        self.jobs[job_id].write(message)

    def controler_socket(self):
        msg = self.controler.recv_multipart()
        try:
            master_id = u(msg[0])
            action = u(msg[1])
            ping_interval = int(msg[2])

            if master_id != "master":
                self.logger.error("Invalid master id '%s'. Should be 'master'",
                                  master_id)
                return
            if action != "PONG":
                self.logger.error("Invalid answer '%s'. Should be 'PONG'",
                                  action)
                return
        except (IndexError, ValueError):
            self.logger.error("Invalid message '%s'", msg)
            return

        if ping_interval < TIMEOUT:
            self.logger.error("invalid ping interval (%d) too small", ping_interval)
            return

        self.logger.debug("master => PONG(%d)", ping_interval)
        self.ping_interval = ping_interval
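
    # Illustrative PONG frame (an assumption for clarity): the master is
    # expected to answer with a multipart message such as
    #   [b"master", b"PONG", b"20"]
    # where the last frame is the new ping interval.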
Example #32
  def run(self):
    self.set_status("Server Startup")
    
    self.set_status("Creating zmq Contexts",1)
    serverctx = zmq.Context() 
    
    self.set_status("Starting zmq ThreadedAuthenticator",1)
    #serverauth = zmq.auth.ThreadedAuthenticator(serverctx)
    serverauth = ThreadAuthenticator(serverctx)
    serverauth.start()
    
    with taco.globals.settings_lock:
      bindip     = taco.globals.settings["Application IP"]
      bindport   = taco.globals.settings["Application Port"]
      localuuid  = taco.globals.settings["Local UUID"]
      publicdir  = os.path.normpath(os.path.abspath(taco.globals.settings["TacoNET Certificates Store"] + "/"  + taco.globals.settings["Local UUID"] + "/public/"))
      privatedir = os.path.normpath(os.path.abspath(taco.globals.settings["TacoNET Certificates Store"] + "/"  + taco.globals.settings["Local UUID"] + "/private/"))

    self.set_status("Configuring Curve to use publickey dir:" + publicdir)
    serverauth.configure_curve(domain='*', location=publicdir)
    #auth.configure_curve(domain='*', location=zmq.auth.CURVE_ALLOW_ANY)

    self.set_status("Creating Server Context",1)
    server = serverctx.socket(zmq.REP)
    server.setsockopt(zmq.LINGER, 0)

    self.set_status("Loading Server Certs",1)
    server_public, server_secret = zmq.auth.load_certificate(os.path.normpath(os.path.abspath(privatedir + "/" + taco.constants.KEY_GENERATION_PREFIX +"-server.key_secret")))
    server.curve_secretkey = server_secret
    server.curve_publickey = server_public
   
    server.curve_server = True
    if bindip == "0.0.0.0": bindip ="*"
    self.set_status("Server is now listening for encrypted ZMQ connections @ "+ "tcp://" + bindip +":" + str(bindport)) 
    server.bind("tcp://" + bindip +":" + str(bindport))
    
    poller = zmq.Poller()
    poller.register(server, zmq.POLLIN|zmq.POLLOUT)

    while not self.stop.is_set():
      socks = dict(poller.poll(200))
      if server in socks and socks[server] == zmq.POLLIN:
        #self.set_status("Getting a request")
        data = server.recv()
        with taco.globals.download_limiter_lock: taco.globals.download_limiter.add(len(data))
        (client_uuid,reply) = taco.commands.Proccess_Request(data)
        if client_uuid!="0": self.set_client_last_request(client_uuid)
      socks = dict(poller.poll(10))
      if server in socks and socks[server] == zmq.POLLOUT:
        #self.set_status("Replying to a request")
        with taco.globals.upload_limiter_lock: taco.globals.upload_limiter.add(len(reply))
        server.send(reply)

        


    self.set_status("Stopping zmq server with 0 second linger")
    server.close(0)
    self.set_status("Stopping zmq ThreadedAuthenticator")
    serverauth.stop() 
    serverctx.term()
    self.set_status("Server Exit")    
Example #33
class ZmqConnector:
    context = None
    auth = None
    public_keys_dir = None
    secret_keys_dir = None
    puller = None
    publisher = None

    HOST = ''

    opponent_id = None
    available_player = 'XXX'

    # Client message protocol:
    #
    # 1. ID: Player ID (random string created by online broker)
    # 2. ACTION: status, command, recipient.
    # 3. MATCH: relevant player match data

    # Server message protocol:
    #
    # 1. Recipient: Player ID of recipient, used for filtering
    # 2. ACTION: sender (SERVER or opponent player's ID), command (welcome, wait, ready, play)
    # 3. Data: forwarded payload

    def __init__(self, host='127.0.0.1'):
        print("[zmq] Initializing ZMQ client object...")
        self.HOST = host
        self.context = zmq.Context()

    def setup(self):
        if not self.check_folder_structure():
            return None
        else:
            self.server_auth()
            self.bind_pull()
            self.bind_pub()

    def check_folder_structure(self):
        keys_dir = os.path.join(os.getcwd(), '../certs')
        print(f"[#] checking folder structure: {keys_dir}")
        self.public_keys_dir = os.path.join(
            keys_dir, 'public')  # has the public keys of registered clients
        self.secret_keys_dir = os.path.join(
            keys_dir, 'private')  # has the server's private cert

        if not (os.path.exists(keys_dir)
                and os.path.exists(self.public_keys_dir)
                and os.path.exists(self.secret_keys_dir)):
            print("[!!] Certificate folders are missing")
            return False
        else:
            return True

    def server_auth(self):
        # Start an authenticator for this context
        print("[#] Starting authenticator...")
        self.auth = ThreadAuthenticator(self.context)
        self.auth.start()
        self.auth.allow(self.HOST)
        # give authenticator access to approved clients' certificate directory
        self.auth.configure_curve(domain='*', location=self.public_keys_dir)

    def bind_pull(self, port=5555):
        print("[zmq] Binding PULL socket : {}".format(port))
        self.puller = self.context.socket(zmq.PULL)
        # feed certificates to socket
        server_secret_file = os.path.join(self.secret_keys_dir,
                                          "server.key_secret")
        self.puller.curve_publickey, self.puller.curve_secretkey = zmq.auth.load_certificate(
            server_secret_file)
        self.puller.curve_server = True  # must come before bind
        self.puller.bind("tcp://*:{}".format(port))

    def pull_receive_multi(self):
        try:
            # message = self.puller.recv_multipart(flags=zmq.DONTWAIT)
            message = self.puller.recv_multipart()
            print(f"[zmq] Received :\n\t{datetime.datetime.now()}- {message}")
            return message
        except zmq.Again as a:
            # print("[!zmq!] Error while getting messages: {}".format(a))
            # print(traceback.format_exc())
            return None
        except zmq.ZMQError as e:
            print("[!zmq!] Error while getting messages: {}".format(e))
            print(traceback.format_exc())
            return None

    def bind_pub(self, port=5556):
        print("[zmq] Binding PUB socket: {}".format(port))
        self.publisher = self.context.socket(zmq.PUB)
        # feed own and approved certificates to socket
        server_secret_file = os.path.join(self.secret_keys_dir,
                                          "server.key_secret")
        self.publisher.curve_publickey, self.publisher.curve_secretkey = zmq.auth.load_certificate(
            server_secret_file)
        self.publisher.curve_server = True  # must come before bind
        self.publisher.bind("tcp://*:{}".format(port))

    def send(self, recipient, info, payload):
        message = list()
        message.append(recipient.encode())
        message.append(json.dumps(info).encode())
        message.append(json.dumps(payload).encode())
        self.pub_send_multi(message)

    def pub_send_multi(self, message):
        try:
            self.publisher.send_multipart(message)
            print(f"[zmq] Sent :\n\t{datetime.datetime.now()}- {message}")
        except TypeError as e:
            print("[!zmq!] TypeError while sending message: {}".format(e))
            print(traceback.format_exc())
        except ValueError as e:
            print("[!zmq!] ValueError while sending message: {}".format(e))
            print(traceback.format_exc())
        except zmq.ZMQError as e:
            print("[!zmq!] ZMQError while sending message: {}".format(e))
            print(traceback.format_exc())

    # GENERIC FUNCTIONS
    def disconnect(self):
        print("[zmq] Disconnecting client...")
        for socket in (self.publisher, self.puller):
            if socket is not None:
                socket.close()
        self.context.term()
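
A rough counterpart client for the ZmqConnector above: it PUSHes a request using the ID/ACTION/MATCH frame layout described in the protocol comments and SUBscribes to replies addressed to its player ID. The certificate file names, ports, player ID, and payload contents are assumptions for illustration.

import json
import zmq
import zmq.auth

player_id = "player-123"  # hypothetical ID handed out by the online broker
context = zmq.Context()

client_public, client_secret = zmq.auth.load_certificate("../certs/private/client.key_secret")
server_public, _ = zmq.auth.load_certificate("../certs/public/server.key")

pusher = context.socket(zmq.PUSH)
pusher.curve_secretkey = client_secret
pusher.curve_publickey = client_public
pusher.curve_serverkey = server_public
pusher.connect("tcp://127.0.0.1:5555")

subscriber = context.socket(zmq.SUB)
subscriber.curve_secretkey = client_secret
subscriber.curve_publickey = client_public
subscriber.curve_serverkey = server_public
subscriber.setsockopt_string(zmq.SUBSCRIBE, player_id)  # filter on the recipient frame
subscriber.connect("tcp://127.0.0.1:5556")

# ID / ACTION / MATCH frames, matching the client message protocol above.
pusher.send_multipart([
    player_id.encode(),
    json.dumps({"status": "ready", "command": "join"}).encode(),
    json.dumps({"match": None}).encode(),
])
print(subscriber.recv_multipart())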
Example #34
class MultiNodeAgent(BEMOSSAgent):
    '''Listens to everything and publishes a heartbeat according to the
    heartbeat period specified in the settings module.
    '''
    def __init__(self, config_path, **kwargs):
        super(MultiNodeAgent, self).__init__(**kwargs)
        #self.node_health = dict()
        #self.node_last_sync = dict()

        self.agent_id = 'multinodeagent'
        self.identity = self.agent_id

        self.multinode_status = dict()
        self.is_parent = False
        self.last_sync_with_parent = datetime(1991, 1,
                                              1)  # equivalent to -ve infinity
        self.parent_node = None
        self.curcon = None  #initialize database connection.
        self.recently_online_node_list = []  # initialize the lists to empty
        self.recently_offline_node_list = [
        ]  # they will be filled as nodes are discovered to be online/offline

        self.offline_variables = offline_variables
        self.offline_variables['logged_by'] = self.agent_id
        self.offline_table = offline_table
        self.offline_log_variables = offline_log_variables

    def getMultinodeData(self):
        self.multinode_data = db_helper.get_multinode_data()

        self.nodelist_dict = {
            node['name']: node
            for node in self.multinode_data['known_nodes']
        }
        self.node_name_list = [
            node['name'] for node in self.multinode_data['known_nodes']
        ]
        self.address_list = [
            node['address'] for node in self.multinode_data['known_nodes']
        ]
        self.server_key_list = [
            node['server_key'] for node in self.multinode_data['known_nodes']
        ]
        self.node_name = self.multinode_data['this_node']

        for index, node in enumerate(self.multinode_data['known_nodes']):
            if node['name'] == self.node_name:
                self.node_index = index
                break
        else:
            raise ValueError(
                '"this_node:" entry on the multinode_data json file is invalid'
            )

        for node_name in self.node_name_list:  #initialize all nodes data
            if node_name not in self.multinode_status:  #initialize new nodes. There could be already the node if this getMultiNode
                # data is being called later
                self.multinode_status[node_name] = dict()
                self.multinode_status[node_name][
                    'health'] = -10  #initialized; never online/offline
                self.multinode_status[node_name]['last_sync_time'] = datetime(
                    1991, 1, 1)
                self.multinode_status[node_name]['last_online_time'] = None
                self.multinode_status[node_name]['last_offline_time'] = None
                self.multinode_status[node_name]['last_scanned_time'] = None

    def configure_authenticator(self):
        self.auth.allow()
        # Tell authenticator to use the certificate in a directory
        self.auth.configure_curve(domain='*', location=self.public_keys_dir)

    @Core.receiver('onsetup')
    def onsetup(self, sender, **kwargs):
        print "Setup"
        self.getMultinodeData()

        base_dir = settings.PROJECT_DIR + "/Agents/MultiNodeAgent/"
        public_keys_dir = os.path.abspath(os.path.join(base_dir,
                                                       'public_keys'))
        secret_keys_dir = os.path.abspath(
            os.path.join(base_dir, 'private_keys'))

        self.secret_keys_dir = secret_keys_dir
        self.public_keys_dir = public_keys_dir

        if not (os.path.exists(public_keys_dir)
                and os.path.exists(secret_keys_dir)):
            logging.critical(
                "Certificates are missing - run generate_certificates.py script first"
            )
            sys.exit(1)

        ctx = zmq.Context.instance()
        self.ctx = ctx
        # Start an authenticator for this context.
        self.auth = ThreadAuthenticator(ctx)
        self.auth.start()
        self.configure_authenticator()

        server = ctx.socket(zmq.PUB)

        server_secret_key_filename = self.multinode_data['known_nodes'][
            self.node_index]['server_secret_key']
        server_secret_file = os.path.join(secret_keys_dir,
                                          server_secret_key_filename)
        server_public, server_secret = zmq.auth.load_certificate(
            server_secret_file)
        server.curve_secretkey = server_secret
        server.curve_publickey = server_public
        server.curve_server = True  # must come before bind
        server.bind(
            self.multinode_data['known_nodes'][self.node_index]['address'])
        self.server = server

    def check_if_parent(self):
        if self.node_name == self.node_name_list[
                0]:  #The first entry is the parent; always
            self.is_parent = True
            self.node_index = 0
            print "I am the boss now, " + self.node_name
            # start the web-server
            subprocess.check_output(settings.PROJECT_DIR +
                                    "/start_webserver.sh " +
                                    settings.PROJECT_DIR,
                                    shell=True)
            message = dict()
            message[STATUS_CHANGE.AGENT_ID] = 'devicediscoveryagent'
            message[STATUS_CHANGE.NODE] = str(self.node_index)
            message[STATUS_CHANGE.AGENT_STATUS] = 'start'
            message[STATUS_CHANGE.
                    NODE_ASSIGNMENT_TYPE] = ZONE_ASSIGNMENT_TYPES.PERMANENT
            self.bemoss_publish('status_change', 'networkagent', [message])
            self.updateParent(self.node_name)
            print "discoveryagent started"

    def disperseMessage(self, topic, header, message):
        for node_name in self.node_name_list:
            if node_name == self.node_name:
                continue
            self.server.send(
                jsonify(node_name + '/republish/' + topic, message))

    def republishToParent(self, topic, header, message):
        if self.is_parent:
            return  #if I am parent, the message is already published
        for node_name in self.node_name_list:
            if self.multinode_status[node_name][
                    'health'] == 2:  #health = 2 is the parent node
                self.server.send(
                    jsonify(node_name + '/republish/' + topic, message))

    @Core.periodic(20)
    def send_heartbeat(self):
        # self.vip.pubsub.publish('pubsub', 'listener', None, {'message': 'Hello Listener'})
        # print 'publishing'
        print "Sending heartbeat"

        last_sync_string = self.last_sync_with_parent.strftime(
            '%B %d, %Y, %H:%M:%S')
        self.server.send(
            jsonify(
                'heartbeat/' + self.node_name + '/' + str(self.is_parent) +
                '/' + last_sync_string, ""))

    def extract_ip(self, addr):
        return re.search(r'([0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3})',
                         addr).groups()[0]

    def getNodeId(self, node_name):

        for index, node in enumerate(self.multinode_data['known_nodes']):
            if node['name'] == node_name:
                node_index = index
                break
        else:
            raise ValueError('the node name: ' + node_name +
                             ' is not found in multinode data')

        return node_index

    def getNodeName(self, node_id):
        return self.multinode_data['known_nodes'][node_id]['name']

    def handle_offline_nodes(self, node_name_list):
        if self.is_parent:
            # start all the agents belonging to that node on this node
            command_group = []
            for node_name in node_name_list:
                node_id = self.getNodeId(node_name)
                #put the offline event into cassandra events log table, and also create notification
                self.offline_variables['date_id'] = str(datetime.now().date())
                self.offline_variables['time'] = datetime.utcnow()
                self.offline_variables['agent_id'] = node_name
                self.offline_variables['event'] = 'node-offline'
                self.offline_variables['reason'] = 'communication-error'
                self.offline_variables['related_to'] = None
                self.offline_variables['event_id'] = uuid.uuid4()
                self.offline_variables['logged_time'] = datetime.utcnow()
                self.TSDCustomInsert(all_vars=self.offline_variables,
                                     log_vars=self.offline_log_variables,
                                     tablename=self.offline_table)
                time = date_converter.UTCToLocal(datetime.utcnow())
                message = str(
                    node_name
                ) + ': ' + 'node-offline. Reason: possibly communication-error'
                self.curcon.execute(
                    "select id from possible_events where event_name=%s",
                    ('node-offline', ))
                event_id = self.curcon.fetchone()[0]
                self.curcon.execute(
                    "insert into notification (dt_triggered, seen, event_type_id, message) VALUES (%s, %s, %s, %s)",
                    (time, False, event_id, message))
                self.curcon.commit()

                # get a list of agents that were supposedly running in that offline node
                self.curcon.execute(
                    "SELECT agent_id FROM " + node_devices_table +
                    " WHERE assigned_node_id=%s", (node_id, ))

                if self.curcon.rowcount:
                    agent_ids = self.curcon.fetchall()

                    for agent_id in agent_ids:
                        message = dict()
                        message[STATUS_CHANGE.AGENT_ID] = agent_id[0]
                        message[STATUS_CHANGE.NODE] = str(self.node_index)
                        message[STATUS_CHANGE.AGENT_STATUS] = 'start'
                        message[
                            STATUS_CHANGE.
                            NODE_ASSIGNMENT_TYPE] = ZONE_ASSIGNMENT_TYPES.TEMPORARY
                        command_group += [message]
            print "moving agents from offline node to parent: " + str(
                node_name_list)
            print command_group
            if command_group:
                self.bemoss_publish('status_change', 'networkagent',
                                    command_group)

    def handle_online_nodes(self, node_name_list):
        if self.is_parent:
            # start all the agents belonging to that nodes back on them
            command_group = []
            for node_name in node_name_list:

                node_id = self.getNodeId(node_name)

                # put the online event into cassandra events log table, and also create notification
                self.offline_variables['date_id'] = str(datetime.now().date())
                self.offline_variables['time'] = datetime.utcnow()
                self.offline_variables['agent_id'] = node_name
                self.offline_variables['event'] = 'node-online'
                self.offline_variables['reason'] = 'communication-restored'
                self.offline_variables['related_to'] = None
                self.offline_variables['event_id'] = uuid.uuid4()
                self.offline_variables['logged_time'] = datetime.utcnow()
                self.TSDCustomInsert(all_vars=self.offline_variables,
                                     log_vars=self.offline_log_variables,
                                     tablename=self.offline_table)
                time = date_converter.UTCToLocal(datetime.utcnow())
                message = str(
                    node_name
                ) + ': ' + 'node-online. Reason: possibly communication-restored'
                self.curcon.execute(
                    "select id from possible_events where event_name=%s",
                    ('node-online', ))
                event_id = self.curcon.fetchone()[0]
                self.curcon.execute(
                    "insert into notification (dt_triggered, seen, event_type_id, message) VALUES (%s, %s, %s, %s)",
                    (time, False, event_id, message))
                self.curcon.commit()

                #get a list of agents that were supposed to be running in that online node
                self.curcon.execute(
                    "SELECT agent_id FROM " + node_devices_table +
                    " WHERE assigned_node_id=%s", (node_id, ))
                if self.curcon.rowcount:
                    agent_ids = self.curcon.fetchall()
                    for agent_id in agent_ids:
                        message = dict()
                        message[STATUS_CHANGE.AGENT_ID] = agent_id[0]
                        message[
                            STATUS_CHANGE.
                            NODE_ASSIGNMENT_TYPE] = ZONE_ASSIGNMENT_TYPES.PERMANENT
                        message[STATUS_CHANGE.NODE] = str(self.node_index)
                        message[STATUS_CHANGE.
                                AGENT_STATUS] = 'stop'  #stop in this node
                        command_group += [message]
                        message = dict(message)  #create another copy
                        message[STATUS_CHANGE.NODE] = str(node_id)
                        message[
                            STATUS_CHANGE.
                            AGENT_STATUS] = 'start'  #start in the target node
                        command_group += [message]

            print "Moving agents back to the online node: " + str(
                node_name_list)
            print command_group

            if command_group:
                self.bemoss_publish('status_change', 'networkagent',
                                    command_group)

    def updateParent(self, parent_node_name):
        parent_ip = self.extract_ip(
            self.nodelist_dict[parent_node_name]['address'])
        write_new = False
        if not os.path.isfile(settings.MULTINODE_PARENT_IP_FILE
                              ):  # but the parent file doesn't exist
            write_new = True
        else:
            with open(settings.MULTINODE_PARENT_IP_FILE, 'r') as f:
                read_ip = f.read()
            if read_ip != parent_ip:
                write_new = True
        if write_new:
            with open(settings.MULTINODE_PARENT_IP_FILE, 'w') as f:
                f.write(parent_ip)
            if self.curcon:
                self.curcon.close()  #close old connection
            self.curcon = db_connection(
            )  #start new connection using new parent_ip
            self.vip.pubsub.publish('pubsub',
                                    'from/multinodeagent/update_parent')

    @Core.periodic(60)
    def check_health(self):

        for node_name, node in self.multinode_status.items():
            if node['health'] > 0:  # initialize all online nodes to 0. If they are really online, they should change it
                # back to 1 or 2 (parent) within 30 seconds through the heartbeat.
                node['health'] = 0

        gevent.sleep(30)
        parent_node_name = None  #initialize parent node
        online_node_exists = False
        for node_name, node in self.multinode_status.items():
            node['last_scanned_time'] = datetime.now()
            if node['health'] == 0:
                node['health'] = -1
                node['last_offline_time'] = datetime.now()
                self.recently_offline_node_list += [node_name]
            elif node['health'] == -1:  #offline since long
                pass
            elif node[
                    'health'] == -10:  #The node was initialized to -10, and never came online. Treat it as recently going
                # offline for this iteration so that the agents that were supposed to be running there can be migrated
                node['health'] = -1
                self.recently_offline_node_list += [node_name]
            elif node['health'] == 2:  #there is some parent node present
                parent_node_name = node_name
            if node['health'] > 0:
                online_node_exists = True  # At least one node (itself) should be online; if not, something is wrong

        assert online_node_exists, "At least one node (current node) must be online"
        if parent_node_name:  #parent node exist
            self.updateParent(parent_node_name)

        for node in self.multinode_data['known_nodes']:
            print node['name'] + ': ' + str(
                self.multinode_status[node['name']]['health'])

        if self.is_parent:
            #if this is a parent node, update the node_info table
            if self.curcon is None:  #if no database connection exists make connection
                self.curcon = db_connection()

            tbl_node_info = settings.DATABASES['default']['TABLE_node_info']
            self.curcon.execute('select node_id from ' + tbl_node_info)
            to_be_deleted_node_ids = self.curcon.fetchall()
            for index, node in enumerate(self.multinode_data['known_nodes']):
                if (index, ) in to_be_deleted_node_ids:
                    to_be_deleted_node_ids.remove(
                        (index, ))  #don't remove this current node
                result = self.curcon.execute(
                    'select * from ' + tbl_node_info + ' where node_id=%s',
                    (index, ))
                node_type = 'parent' if self.multinode_status[
                    node['name']]['health'] == 2 else "child"
                node_status = "ONLINE" if self.multinode_status[
                    node['name']]['health'] > 0 else "OFFLINE"
                ip_address = self.extract_ip(node['address'])
                last_scanned_time = self.multinode_status[
                    node['name']]['last_online_time']
                last_offline_time = self.multinode_status[
                    node['name']]['last_offline_time']
                last_sync_time = self.multinode_status[
                    node['name']]['last_sync_time']

                var_list = "(node_id,node_name,node_type,node_status,ip_address,last_scanned_time,last_offline_time,last_sync_time)"
                value_placeholder_list = "(%s,%s,%s,%s,%s,%s,%s,%s)"
                actual_values_list = (index, node['name'], node_type,
                                      node_status, ip_address,
                                      last_scanned_time, last_offline_time,
                                      last_sync_time)

                if self.curcon.rowcount == 0:
                    self.curcon.execute(
                        "insert into " + tbl_node_info + " " + var_list +
                        " VALUES" + value_placeholder_list, actual_values_list)
                else:
                    self.curcon.execute(
                        "update " + tbl_node_info + " SET " + var_list +
                        " = " + value_placeholder_list + " where node_id = %s",
                        actual_values_list + (index, ))
            self.curcon.commit()

            for id in to_be_deleted_node_ids:
                self.curcon.execute(
                    'delete from accounts_userprofile_nodes where nodeinfo_id=%s',
                    id)  #delete entries in user-profile for the old node
                self.curcon.commit()
                self.curcon.execute('delete from ' + tbl_node_info +
                                    ' where node_id=%s',
                                    id)  #delete the old nodes
                self.curcon.commit()

            if self.recently_online_node_list:  #Online nodes should be handled first because, the same node can first be
                #on both recently_online_node_list and recently_offline_node_list, when it goes offline shortly after
                #coming online
                self.handle_online_nodes(self.recently_online_node_list)
                self.recently_online_node_list = []  # reset after handling
            if self.recently_offline_node_list:
                self.handle_offline_nodes(self.recently_offline_node_list)
                self.recently_offline_node_list = []  # reset after handling

    def connect_client(self, node):
        server_public_file = os.path.join(self.public_keys_dir,
                                          node['server_key'])
        server_public, _ = zmq.auth.load_certificate(server_public_file)
        # The client must know the server's public key to make a CURVE connection.
        self.client.curve_serverkey = server_public
        self.client.setsockopt(zmq.SUBSCRIBE, 'heartbeat/')
        self.client.setsockopt(zmq.SUBSCRIBE, self.node_name)
        self.client.connect(node['address'])

    def disconnect_client(self, node):
        self.client.disconnect(node['address'])

    @Core.receiver('onstart')
    def onstart(self, sender, **kwargs):

        self.check_if_parent()
        print "Starting to receive Heart-beat"
        self.vip.heartbeat.start_with_period(15)
        client = self.ctx.socket(zmq.SUB)
        # We need two certificates, one for the client and one for
        # the server. The client must know the server's public key
        # to make a CURVE connection.

        client_secret_key_filename = self.multinode_data['known_nodes'][
            self.node_index]['client_secret_key']
        client_secret_file = os.path.join(self.secret_keys_dir,
                                          client_secret_key_filename)
        client_public, client_secret = zmq.auth.load_certificate(
            client_secret_file)
        client.curve_secretkey = client_secret
        client.curve_publickey = client_public

        self.client = client

        for node in self.multinode_data['known_nodes']:
            self.connect_client(node)

        print "Starting to listen"
        try:
            while True:  #read messages
                if client.poll(1000):
                    topic, msg = dejsonify(client.recv())
                    topic_list = topic.split('/')
                    if topic_list[0] == 'heartbeat':
                        node_name = topic_list[1]
                        is_parent = topic_list[2]
                        last_sync_with_parent = topic_list[3]
                        if self.multinode_status[node_name][
                                'health'] < 0:  #the node health was <0 , means offline
                            print node_name + " is back online"
                            self.recently_online_node_list += [node_name]

                        if is_parent.lower() in ['false', '0']:
                            self.multinode_status[node_name]['health'] = 1
                        elif is_parent.lower() in ['true', '1']:
                            self.multinode_status[node_name]['health'] = 2
                            self.parent_node = node_name
                        else:
                            raise ValueError(
                                'Invalid is_parent string in heart-beat message'
                            )

                        self.multinode_status[node_name][
                            'last_online_time'] = datetime.now()

                    if topic_list[0] == self.node_name:
                        #message addressed to this node

                        if topic_list[1] == 'republish':
                            new_topic = '/'.join(
                                topic_list[2:] +
                                ['repub-by-' + self.node_name, 'republished'])
                            self.vip.pubsub.publish('pubsub', new_topic, None,
                                                    msg)

                    print self.node_name + ": " + topic, str(msg)

                else:
                    gevent.sleep(2)

        except Exception as er:
            print "error"
            print er

        # stop auth thread
        self.auth.stop()

    @PubSub.subscribe('pubsub', 'to/multinodeagent/')
    def updateMultinodeData(self, peer, sender, bus, topic, headers, message):
        print "Updating Multinode data"
        topic_list = topic.split('/')
        self.configure_authenticator()
        #to/multinodeagent/from/<doesn't matter>/update_multinode_data
        if topic_list[4] == 'update_multinode_data':
            old_multinode_data = self.multinode_data
            self.getMultinodeData()
            for node in self.multinode_data['known_nodes']:
                if node not in old_multinode_data['known_nodes']:
                    print "New node has been added to the cluster: " + node[
                        'name']
                    print "We will connect to it"
                    self.connect_client(node)

            for node in old_multinode_data['known_nodes']:
                if node not in self.multinode_data['known_nodes']:
                    print "Node has been removed from the cluster: " + node[
                        'name']
                    print "We will disconnect from it"
                    self.disconnect_client(node)
                    # TODO: remove it from the node_info table

        print "yay! got it"

    @PubSub.subscribe('pubsub', 'to/')
    def relayToMessage(self, peer, sender, bus, topic, headers, message):
        print topic
        topic_list = topic.split('/')
        #to/<some_agent_or_ui>/topic/from/<some_agent_or_ui>
        to_index = topic_list.index('to') + 1
        if 'from' in topic_list:
            from_index = topic_list.index('from') + 1
            from_entity = topic_list[from_index]

        to_entity = topic_list[to_index]
        last_field = topic_list[-1]
        if last_field == 'republished':  #it is already a republished message, no need to republish
            return
        if to_entity in settings.SYSTEM_AGENTS:
            self.disperseMessage(topic, headers,
                                 message)  #republish to all nodes
        elif to_entity in settings.PARENT_NODE_SYSTEM_AGENTS:
            if not self.is_parent:
                self.republishToParent(topic, headers, message)
        else:
            self.curcon.execute(
                "SELECT current_node_id FROM " + node_devices_table +
                " WHERE agent_id=%s", (to_entity, ))
            if self.curcon.rowcount:
                node_id = self.curcon.fetchone()[0]
                if node_id != self.node_index:
                    self.server.send(
                        jsonify(
                            self.getNodeName(node_id) + '/republish/' + topic,
                            message))

    @PubSub.subscribe('pubsub', 'from/')
    def relayFromMessage(self, peer, sender, bus, topic, headers, message):
        topic_list = topic.split('/')
        #from/<some_agent_or_ui>/topic
        from_entity = topic_list[1]
        last_field = topic_list[-1]
        if last_field == 'republished':  #it is a republished message, no need to publish
            return
        self.disperseMessage(topic, headers, message)  #republish to all nodes

    @PubSub.subscribe('pubsub', '')
    def on_match(self, peer, sender, bus, topic, headers, message):
        '''Use match_all to receive all messages and print them out.'''
        if sender == 'pubsub.compat':
            message = compat.unpack_legacy_message(headers, message)
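
The agent above exits at setup unless the public_keys/private_keys directories already exist (they are produced by a separate generate_certificates.py script). Below is a minimal sketch of generating such CURVE keypairs with pyzmq's helpers; the directory layout mirrors the example and is an assumption.

import os
import shutil
import zmq.auth

base_dir = os.path.dirname(os.path.abspath(__file__))
keys_dir = os.path.join(base_dir, 'certificates')
public_keys_dir = os.path.join(base_dir, 'public_keys')
secret_keys_dir = os.path.join(base_dir, 'private_keys')

for directory in (keys_dir, public_keys_dir, secret_keys_dir):
    if not os.path.exists(directory):
        os.makedirs(directory)

# Create a server and a client keypair, then sort the files into the
# public_keys/private_keys directories the agent expects.
for name in ('server', 'client'):
    public_file, secret_file = zmq.auth.create_certificates(keys_dir, name)
    shutil.move(public_file, os.path.join(public_keys_dir, os.path.basename(public_file)))
    shutil.move(secret_file, os.path.join(secret_keys_dir, os.path.basename(secret_file)))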
Example #35
class RpcClient:
    """"""
    def __init__(self):
        """Constructor"""
        # zmq port related
        self.__context: zmq.Context = zmq.Context()

        # Request socket (Request–reply pattern)
        self.__socket_req: zmq.Socket = self.__context.socket(zmq.REQ)

        # Subscribe socket (Publish–subscribe pattern)
        self.__socket_sub: zmq.Socket = self.__context.socket(zmq.SUB)

        # Worker thread related, used to process data pushed from the server
        self.__active: bool = False  # RpcClient status
        self.__thread: threading.Thread = None  # RpcClient thread
        self.__lock: threading.Lock = threading.Lock()

        # Authenticator used to ensure data security
        self.__authenticator: ThreadAuthenticator = None

        self._last_received_ping: datetime = datetime.utcnow()

    @lru_cache(100)
    def __getattr__(self, name: str):
        """
        Realize remote call function
        """

        # Perform remote call task
        def dorpc(*args, **kwargs):
            # Get timeout value from kwargs, default value is 30 seconds
            if "timeout" in kwargs:
                timeout = kwargs.pop("timeout")
            else:
                timeout = 30000

            # Generate request
            req = [name, args, kwargs]

            # Send request and wait for response
            with self.__lock:
                self.__socket_req.send_pyobj(req)

                # Timeout reached without any data
                n = self.__socket_req.poll(timeout)
                if not n:
                    msg = f"Timeout of {timeout}ms reached for {req}"
                    raise RemoteException(msg)

                rep = self.__socket_req.recv_pyobj()

            # Return response if succeeded; raise exception if failed
            if rep[0]:
                return rep[1]
            else:
                raise RemoteException(rep[1])

        return dorpc

    def start(self,
              req_address: str,
              sub_address: str,
              client_secretkey_path: str = "",
              server_publickey_path: str = "",
              username: str = "",
              password: str = "") -> None:
        """
        Start RpcClient
        """
        if self.__active:
            return

        # Start authenticator
        if client_secretkey_path and server_publickey_path:
            self.__authenticator = ThreadAuthenticator(self.__context)
            self.__authenticator.start()
            self.__authenticator.configure_curve(
                domain="*", location=zmq.auth.CURVE_ALLOW_ANY)

            publickey, secretkey = zmq.auth.load_certificate(
                client_secretkey_path)
            serverkey, _ = zmq.auth.load_certificate(server_publickey_path)

            self.__socket_sub.curve_secretkey = secretkey
            self.__socket_sub.curve_publickey = publickey
            self.__socket_sub.curve_serverkey = serverkey

            self.__socket_req.curve_secretkey = secretkey
            self.__socket_req.curve_publickey = publickey
            self.__socket_req.curve_serverkey = serverkey
        elif username and password:
            self.__authenticator = ThreadAuthenticator(self.__context)
            self.__authenticator.start()
            self.__authenticator.configure_plain(
                domain="*", passwords={username: password})

            self.__socket_sub.plain_username = username.encode()
            self.__socket_sub.plain_password = password.encode()

            self.__socket_req.plain_username = username.encode()
            self.__socket_req.plain_password = password.encode()

        # Connect zmq port
        self.__socket_req.connect(req_address)
        self.__socket_sub.connect(sub_address)

        # Start RpcClient status
        self.__active = True

        # Start RpcClient thread
        self.__thread = threading.Thread(target=self.run)
        self.__thread.start()

        self._last_received_ping = datetime.utcnow()

    def stop(self) -> None:
        """
        Stop RpcClient
        """
        if not self.__active:
            return

        # Stop RpcClient status
        self.__active = False

    def join(self) -> None:
        # Wait for RpcClient thread to exit
        if self.__thread and self.__thread.is_alive():
            self.__thread.join()
        self.__thread = None

    def run(self) -> None:
        """
        Run RpcClient function
        """
        pull_tolerance = int(KEEP_ALIVE_TOLERANCE.total_seconds() * 1000)

        while self.__active:
            if not self.__socket_sub.poll(pull_tolerance):
                self.on_disconnected()
                continue

            # Receive data from subscribe socket
            topic, data = self.__socket_sub.recv_pyobj(flags=NOBLOCK)

            if topic == KEEP_ALIVE_TOPIC:
                self._last_received_ping = data
            else:
                # Process data by callable function
                self.callback(topic, data)

        # Close socket
        self.__socket_req.close()
        self.__socket_sub.close()

    def callback(self, topic: str, data: Any) -> None:
        """
        Callable function
        """
        raise NotImplementedError

    def subscribe_topic(self, topic: str) -> None:
        """
        Subscribe data
        """
        self.__socket_sub.setsockopt_string(zmq.SUBSCRIBE, topic)

    def on_disconnected(self):
        """
        Callback when heartbeat is lost.
        """
        print(
            "RpcServer has no response over {tolerance} seconds, please check your connection."
            .format(tolerance=KEEP_ALIVE_TOLERANCE.total_seconds()))
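
A minimal usage sketch for the RpcClient above, assuming a reachable server that registered an add function; the addresses, topic filter, and omitted key handling are placeholders, not part of the original module.

class MyRpcClient(RpcClient):
    def callback(self, topic, data):
        # Handle data pushed from the server's PUB socket.
        print("received", topic, data)

client = MyRpcClient()
client.subscribe_topic("")                       # empty prefix subscribes to everything
client.start("tcp://127.0.0.1:2014", "tcp://127.0.0.1:4102")
print(client.add(1, 2))                          # resolved remotely via __getattr__
client.stop()
client.join()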
Example #36
class RpcServer:
    """"""
    def __init__(self):
        """
        Constructor
        """
        # Save functions dict: key is function name, value is function object
        self.__functions: Dict[str, Any] = {}

        # Zmq port related
        self.__context: zmq.Context = zmq.Context()

        # Reply socket (Request–reply pattern)
        self.__socket_rep: zmq.Socket = self.__context.socket(zmq.REP)

        # Publish socket (Publish–subscribe pattern)
        self.__socket_pub: zmq.Socket = self.__context.socket(zmq.PUB)

        # Worker thread related
        self.__active: bool = False  # RpcServer status
        self.__thread: threading.Thread = None  # RpcServer thread
        self.__lock: threading.Lock = threading.Lock()

        # Authenticator used to ensure data security
        self.__authenticator: ThreadAuthenticator = None

    def is_active(self) -> bool:
        """"""
        return self.__active

    def start(self,
              rep_address: str,
              pub_address: str,
              server_secretkey_path: str = "",
              username: str = "",
              password: str = "") -> None:
        """
        Start RpcServer
        """
        if self.__active:
            return

        # Start authenticator
        if server_secretkey_path:
            self.__authenticator = ThreadAuthenticator(self.__context)
            self.__authenticator.start()
            self.__authenticator.configure_curve(
                domain="*", location=zmq.auth.CURVE_ALLOW_ANY)

            publickey, secretkey = zmq.auth.load_certificate(
                server_secretkey_path)

            self.__socket_pub.curve_secretkey = secretkey
            self.__socket_pub.curve_publickey = publickey
            self.__socket_pub.curve_server = True

            self.__socket_rep.curve_secretkey = secretkey
            self.__socket_rep.curve_publickey = publickey
            self.__socket_rep.curve_server = True
        elif username and password:
            self.__authenticator = ThreadAuthenticator(self.__context)
            self.__authenticator.start()
            self.__authenticator.configure_plain(
                domain="*", passwords={username: password})

            self.__socket_pub.plain_server = True
            self.__socket_rep.plain_server = True

        # Bind socket address
        self.__socket_rep.bind(rep_address)
        self.__socket_pub.bind(pub_address)

        # Start RpcServer status
        self.__active = True

        # Start RpcServer thread
        self.__thread = threading.Thread(target=self.run)
        self.__thread.start()

    def stop(self) -> None:
        """
        Stop RpcServer
        """
        if not self.__active:
            return

        # Stop RpcServer status
        self.__active = False

    def join(self) -> None:
        # Wait for RpcServer thread to exit
        if self.__thread and self.__thread.is_alive():
            self.__thread.join()
        self.__thread = None

    def run(self) -> None:
        """
        Run RpcServer functions
        """
        start = datetime.utcnow()

        while self.__active:
            # Use poll to wait for event arrival; waiting time is 1 second (1000 milliseconds)
            cur = datetime.utcnow()
            delta = cur - start

            if delta >= KEEP_ALIVE_INTERVAL:
                self.publish(KEEP_ALIVE_TOPIC, cur)

            if not self.__socket_rep.poll(1000):
                continue

            # Receive request data from Reply socket
            req = self.__socket_rep.recv_pyobj()

            # Get function name and parameters
            name, args, kwargs = req

            # Try to get and execute callable function object; capture exception information if it fails
            try:
                func = self.__functions[name]
                r = func(*args, **kwargs)
                rep = [True, r]
            except Exception as e:  # noqa
                rep = [False, traceback.format_exc()]

            # send callable response by Reply socket
            self.__socket_rep.send_pyobj(rep)

        # Unbind socket address
        self.__socket_pub.unbind(self.__socket_pub.LAST_ENDPOINT)
        self.__socket_rep.unbind(self.__socket_rep.LAST_ENDPOINT)

    def publish(self, topic: str, data: Any) -> None:
        """
        Publish data
        """
        with self.__lock:
            self.__socket_pub.send_pyobj([topic, data])

    def register(self, func: Callable) -> None:
        """
        Register function
        """
        self.__functions[func.__name__] = func
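
A minimal sketch pairing with the client above: register a plain function and start the RpcServer with PLAIN authentication. The addresses and credentials are placeholders; a CURVE setup would pass server_secretkey_path instead.

def add(a, b):
    return a + b

server = RpcServer()
server.register(add)
server.start(
    rep_address="tcp://*:2014",
    pub_address="tcp://*:4102",
    username="demo",
    password="demo-password",
)
# ... serve until shutdown is requested ...
server.stop()
server.join()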
Example #37
class StratusApp(StratusServerApp):
    def __init__(self, core: StratusCore, **kwargs):
        StratusServerApp.__init__(self, core, **kwargs)
        self.logger = StratusLogger.getLogger()
        self.active = True
        self.parms = self.getConfigParms('stratus')
        self.client_address = self.parms.get("client_address", "*")
        self.request_port = self.parms.get("request_port", 4556)
        self.response_port = self.parms.get("response_port", 4557)
        self.active_handlers = {}
        self.getCertDirs()

    def getCertDirs(
        self
    ):  # These directories are generated by the generate_certificates script
        self.cert_dir = self.parms.get("certificate_path",
                                       os.path.expanduser("~/.stratus/zmq"))
        self.logger.info(
            f"Loading certificates and keys from directory {self.cert_dir}")
        self.keys_dir = os.path.join(self.cert_dir, 'certificates')
        self.public_keys_dir = os.path.join(self.cert_dir, 'public_keys')
        self.secret_keys_dir = os.path.join(self.cert_dir, 'private_keys')

        if not (os.path.exists(self.keys_dir)
                and os.path.exists(self.public_keys_dir)
                and os.path.exists(self.secret_keys_dir)):
            from stratus.handlers.zeromq.security.generate_certificates import generate_certificates
            generate_certificates(self.cert_dir)

    def initSocket(self):
        try:
            server_secret_file = os.path.join(self.secret_keys_dir,
                                              "server.key_secret")
            server_public, server_secret = zmq.auth.load_certificate(
                server_secret_file)
            # TODO: this is commented to avoid key checking
            #self.request_socket.curve_secretkey = server_secret
            #self.request_socket.curve_publickey = server_public
            #self.request_socket.curve_server = True
            self.request_socket.bind("tcp://{}:{}".format(
                self.client_address, self.request_port))
            self.logger.info(
                "@@STRATUS-APP --> Bound authenticated request socket to client at {} on port: {}"
                .format(self.client_address, self.request_port))
        except Exception as err:
            self.logger.error(
                "@@STRATUS-APP: Error initializing request socket on {}, port {}: {}"
                .format(self.client_address, self.request_port, err))
            self.logger.error(traceback.format_exc())

    def addHandler(self, clientId, jobId, handler):
        self.active_handlers[clientId + "-" + jobId] = handler
        return handler

    def removeHandler(self, clientId, jobId):
        handlerId = clientId + "-" + jobId
        try:
            del self.active_handlers[handlerId]
        except:
            self.logger.error("Error removing handler: " + handlerId +
                              ", active handlers = " +
                              str(list(self.active_handlers.keys())))

    def setExeStatus(self, submissionId: str, status: Status):
        self.responder.setExeStatus(submissionId, status)

    def sendResponseMessage(self, msg: StratusResponse) -> str:
        request_args = [msg.id, msg.message]
        packaged_msg = "!".join(request_args)
        timeStamp = datetime.datetime.now().strftime("%m/%d %H:%M:%S")
        self.logger.info(
            "@@STRATUS-APP: Sending response {} on request_socket @({}): {}".
            format(msg.id, timeStamp, str(msg)))
        self.request_socket.send_string(packaged_msg)
        return packaged_msg

    def initInteractions(self):
        try:
            self.zmqContext: zmq.Context = zmq.Context()

            self.auth = ThreadAuthenticator(self.zmqContext)
            self.auth.start()
            self.auth.allow("192.168.0.22")
            self.auth.allow(self.client_address)
            self.auth.configure_curve(
                domain='*', location=zmq.auth.CURVE_ALLOW_ANY
            )  # self.public_keys_dir )  # Use 'location=zmq.auth.CURVE_ALLOW_ANY' for stonehouse security

            self.request_socket: zmq.Socket = self.zmqContext.socket(zmq.REP)
            self.responder = StratusZMQResponder(
                self.zmqContext,
                self.response_port,
                client_address=self.client_address,
                certificate_path=self.cert_dir)
            self.initSocket()
            self.logger.info(
                "@@STRATUS-APP:Listening for requests on port: {}".format(
                    self.request_port))

        except Exception as err:
            self.logger.error(
                "@@STRATUS-APP:  ------------------------------- StratusApp Init error: {} ------------------------------- "
                .format(err))

    def processResults(self):
        completed_workflows = self.responder.processWorkflows(
            self.getWorkflows())
        for rid in completed_workflows:
            self.clearWorkflow(rid)

    def processRequests(self):
        while self.request_socket.poll(0) != 0:
            request_header = self.request_socket.recv_string().strip().strip(
                "'")
            parts = request_header.split("!")
            submissionId = str(parts[0])
            rType = str(parts[1])
            request: Dict = json.loads(parts[2]) if len(parts) > 2 else ""
            try:
                self.logger.info(
                    "@@STRATUS-APP:  ###  Processing {} request: {}".format(
                        rType, request))
                if rType == "capabilities":
                    response = self.core.getCapabilities(request["type"])
                    self.sendResponseMessage(
                        StratusResponse(submissionId, response))
                elif rType == "exe":
                    if len(parts) <= 2:
                        raise Exception("Missing parameters to exe request")
                    request["rid"] = submissionId
                    self.logger.info(
                        "Processing zmq Request: '{}' '{}' '{}'".format(
                            submissionId, rType, str(request)))
                    self.submitWorkflow(
                        request)  #   TODO: Send results when tasks complete.
                    response = {"status": "Executing"}
                    self.sendResponseMessage(
                        StratusResponse(submissionId, response))
                elif rType == "quit" or rType == "shutdown":
                    response = {"status": "Terminating"}
                    self.sendResponseMessage(
                        StratusResponse(submissionId, response))
                    self.logger.info(
                        "@@STRATUS-APP: Received Shutdown Message")
                    exit(0)
                else:
                    msg = "@@STRATUS-APP: Unknown request type: " + rType
                    self.logger.info(msg)
                    response = {"status": "error", "error": msg}
                    self.sendResponseMessage(
                        StratusResponse(submissionId, response))
            except Exception as ex:
                self.processError(submissionId, ex)

    def processError(self, rid: str, ex: Exception):
        tb = traceback.format_exc()
        self.logger.error("@@STRATUS-APP: Execution error: " + str(ex))
        self.logger.error(tb)
        response = {"status": "error", "error": str(ex), "traceback": tb}
        self.sendResponseMessage(StratusResponse(rid, response))

    def updateInteractions(self):
        self.processRequests()
        self.processResults()

    def term(self, msg):
        self.logger.info("@@STRATUS-APP: !!EDAS Shutdown: " + msg)
        self.active = False
        self.auth.stop()
        self.logger.info("@@STRATUS-APP: QUIT PythonWorkerPortal")
        try:
            self.request_socket.close()
        except Exception:
            pass
        self.logger.info("@@STRATUS-APP: CLOSE request_socket")
        self.responder.close_connection()
        self.logger.info("@@STRATUS-APP: TERM responder")
        self.shutdown()
        self.logger.info("@@STRATUS-APP: shutdown complete")
Example #38
File: device.py  Project: RIAPS/riaps-pycom
class Device(Actor):
    '''
    The actor class implements all the management and control functions over its components
    '''          

    def __init__(self, gModel, gModelName, dName, qName, sysArgv):
        '''
        Constructor
        
        :param dName: device type name
        :type dName: str
        
        :param qName: qualified name of the device instance: 'actor.inst'
        :type qName: str
         
        '''
        self.logger = logging.getLogger(__name__)
        self.inst_ = self
        self.appName = gModel["name"]
        self.modelName = gModelName
        aName,iName = qName.split('.')
        self.name = qName
        self.iName = iName
        self.dName = dName 
        self.pid = os.getpid()
        self.uuid = None
        self.suffix = ""
        self.setupIfaces()
        # Assumption : pid is a 4 byte int
        self.actorID = ipaddress.IPv4Address(self.globalHost).packed + self.pid.to_bytes(4, 'big')
        if dName not in gModel["devices"]:
            raise BuildError('Device "%s" unknown' % dName)
       
        # In order to make the rest of the code work, we build an actor model for the device
        devModel = gModel["devices"][dName]
        self.model = {}  # The made-up actor model
        
        formals = devModel["formals"]  # Formals are the same as those of the device (component)
        self.model["formals"] = formals

        devInst = { "type": dName }  # There is a single instance, containing the device component
        actuals = []
        for arg in  formals:
            name = arg["name"]
            actual = {}
            actual["name"] = name
            actual["param"] = name
            actuals.append(actual)
        devInst["actuals"] = actuals
        
        self.model["instances"] = { iName: devInst}     # Single instance (under iName)
        
        aModel = gModel["actors"][aName]
        self.model["locals"] = aModel["locals"]         # Locals
        self.model["internals"] = aModel["internals"]   # Internals 
        
        self.INT_RE = re.compile(r"^[-]?\d+$")
        self.parseParams(sysArgv)
        
        # Use czmq's context
        czmq_ctx = Zsys.init()
        self.context = zmq.Context.shadow(czmq_ctx.value)
        Zsys.handler_reset()  # Reset previous signal 
        
        # Context for app sockets
        self.appContext = zmq.Context()
        
        if Config.SECURITY:
            (self.public_key, self.private_key) = zmq.auth.load_certificate(const.appCertFile)
            _public = zmq.curve_public(self.private_key)
            if(self.public_key != _public):
                self.logger.error("bad security key(s)")
                raise BuildError("invalid security key(s)")
            hosts = ['127.0.0.1']
            try:
                with open(const.appDescFile, 'r') as f:
                    content = yaml.load(f, Loader=yaml.Loader)
                    hosts += content.hosts
            except:
                self.logger.error("Error loading app descriptor:s", str(sys.exc_info()[1]))

            self.auth = ThreadAuthenticator(self.appContext)
            self.auth.start()
            self.auth.allow(*hosts)
            self.auth.configure_curve(domain='*', location=zmq.auth.CURVE_ALLOW_ANY)
        else:
            (self.public_key, self.private_key) = (None, None)
            self.auth = None
            self.appContext = self.context
        
        try:
            if os.path.isfile(const.logConfFile) and os.access(const.logConfFile, os.R_OK):
                spdlog_setup.from_file(const.logConfFile)      
        except Exception as e:
            self.logger.error("error while configuring componentLogger: %s" % repr(e))  
        
        messages = gModel["messages"]  # Global message types (global on the network)
        self.messageNames = []
        for messageSpec in messages:
            self.messageNames.append(messageSpec["name"])
                   
        locals_ = self.model["locals"]  # Local message types (local to the host)
        self.localNames = []
        for messageSpec in locals_:
            self.localNames.append(messageSpec["type"]) 
            
        internals = self.model["internals"]  # Internal message types (internal to the actor process)
        self.internalNames = []
        for messageSpec in internals:
            self.internalNames.append(messageSpec["type"])
            
        groups = gModel["groups"]
        self.groupTypes = {} 
        for group in groups:
            self.groupTypes[group["name"]] = { 
                "kind": group["kind"],
                "message": group["message"],
                "timed": group["timed"]
            }
            
        self.components = {}
        instSpecs = self.model["instances"]
        _compSpecs = gModel["components"]
        devSpecs = gModel["devices"]
        for instName in instSpecs:  # Create the component instances: the 'parts'
            instSpec = instSpecs[instName]
            instType = instSpec['type']
            if instType in devSpecs: 
                typeSpec = devSpecs[instType]
            else:
                raise BuildError('Device type "%s" for instance "%s" is undefined' % (instType, instName))
            instFormals = typeSpec['formals']
            instActuals = instSpec['actuals']
            instArgs = self.buildInstArgs(instName, instFormals, instActuals)
            # Check whether the component is C++ component
            ccComponentFile = 'lib' + instType.lower() + '.so'
            ccComp = os.path.isfile(ccComponentFile)
            try:
                if ccComp:
                    modObj = importlib.import_module('lib' + instType.lower())
                    self.components[instName] = modObj.create_component_py(self, self.model,
                                                                           typeSpec, instName,
                                                                           instType, instArgs,
                                                                           self.appName, self.name, groups)
                else:
                    self.components[instName] = Part(self, typeSpec, instName, instType, instArgs)
            except Exception as e:
                traceback.print_exc()
                self.logger.error("Error while constructing part '%s.%s': %s" % (instType, instName, str(e)))
    
    def getPortMessageTypes(self, ports, key, kinds, res):
        for _name, spec in ports[key].items():
            for kind in kinds:
                typeName = spec[kind]
                res.append({"type": typeName})
        
    def getMessageTypes(self, devModel):
        res = []
        ports = devModel["ports"]
        self.getPortMessageTypes(ports, "pubs", ["type"], res)
        self.getPortMessageTypes(ports, "subs", ["type"], res)
        self.getPortMessageTypes(ports, "reqs", ["req_type", "rep_type"], res)
        self.getPortMessageTypes(ports, "reps", ["req_type", "rep_type"], res)
        self.getPortMessageTypes(ports, "clts", ["req_type", "rep_type"], res)
        self.getPortMessageTypes(ports, "srvs", ["req_type", "rep_type"], res)
        self.getPortMessageTypes(ports, "qrys", ["req_type", "rep_type"], res)
        self.getPortMessageTypes(ports, "anss", ["req_type", "rep_type"], res)
        return res
                     
    def isDevice(self):
        return True 
    
    def setup(self):
        '''
        Perform a setup operation on the actor (after the initial construction but before the activation of parts)
        '''
        self.logger.info("setup")
        # self.setupIfaces()
        self.suffix = self.macAddress
        self.disco = DiscoClient(self, self.suffix)
        self.disco.start()                      # Start the discovery service client
        self.disco.registerActor()              # Register this actor with the discovery service
        self.logger.info("device registered with disco")
        self.deplc = DeplClient(self, self.suffix)
        self.deplc.start()
        ok = self.deplc.registerActor()       
        self.logger.info("device %s registered with depl" % ("is" if ok else "is not"))
        self.controls = { }
        self.controlMap = { }
        for inst in self.components:
            comp = self.components[inst]
            control = self.context.socket(zmq.PAIR)
            control.bind('inproc://part_' + inst + '_control')
            self.controls[inst] = control
            self.controlMap[id(control)] = comp 
            if isinstance(comp, Part):
                self.components[inst].setup(control)
            else:
                self.components[inst].setup()

    def terminate(self):
        self.logger.info("terminating")
        for component in self.components.values():
            component.terminate()
        # self.devc.terminate()
        self.disco.terminate()
        # Clean up everything
        # self.context.destroy()
        time.sleep(1.0)
        self.logger.info("terminated")
        os._exit(0)
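
The setup() method above wires every part to an inproc PAIR "control" endpoint. The following is a minimal, self-contained sketch of that pattern (names and messages are illustrative, not taken from the framework): one thread binds the control endpoint, the part thread connects to it, and the two exchange control messages.

import threading
import zmq

ctx = zmq.Context.instance()

def part_thread(name):
    # Part side: connect to the control endpoint bound by the actor
    sock = ctx.socket(zmq.PAIR)
    sock.connect('inproc://part_' + name + '_control')
    command = sock.recv_string()           # wait for a control command
    sock.send_string('done: ' + command)   # acknowledge it
    sock.close()

# Actor side: bind first so the inproc endpoint exists before the part connects
control = ctx.socket(zmq.PAIR)
control.bind('inproc://part_worker_control')
worker = threading.Thread(target=part_thread, args=('worker',))
worker.start()
control.send_string('activate')
print(control.recv_string())
worker.join()
control.close()
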
Example #39
class FeatureComputer(object):
    def __init__(self,
                 bind_str="tcp://127.0.0.1:5560",
                 parent_model=None,
                 layer=None,
                 logins=None,
                 viewable_layers=None):
        self.context = zmq.Context.instance()
        self.auth = ThreadAuthenticator(self.context)
        self.auth.start()
        #auth.allow('127.0.0.1')
        self.auth.configure_plain(domain='*', passwords=logins)
        self.socket = self.context.socket(zmq.PAIR)
        self.socket.plain_server = True
        self.socket.bind(bind_str)
        self.parent_model = parent_model
        self.curr_model = parent_model
        self.viewable_layers = viewable_layers

        self.config = tf.ConfigProto()
        self.config.gpu_options.per_process_gpu_memory_fraction = 0.3
        self.config.gpu_options.allow_growth = True

        if not layer:
            self.layer = 5  #len(self.parent_model.layers) - 1
        else:
            self.layer = layer

    def change_layer(self, *args, **kwargs):
        print("Changing layer")
        print(args, kwargs)
        self.layer = kwargs.get('layer', self.layer)
        if not self.viewable_layers:
            self.curr_model = Model(
                input=[self.parent_model.layers[0].input],
                output=[self.parent_model.layers[self.layer].output])
        else:
            self.curr_model = Model(
                input=[self.parent_model.layers[0].input],
                output=[self.viewable_layers[self.layer].output])
        #set_session(tf.Session(config=self.config))
        print(self.layer)
        self.socket.send_pyobj({
            'type': 'layer_changed',
            'success': True
        }, zmq.NOBLOCK)

    def get_summary(self, *args, **kwargs):
        self.socket.send_pyobj({
            'type': 'summary',
            'result': None
        }, zmq.NOBLOCK)

    def do_predict(self, *args, **kwargs):
        # TODO: Make this configurable
        input = kwargs.pop('input', np.zeros((1, 224, 224, 3)))

        resized = np.float64(cv2.resize(input, (224, 224)))
        preprocessed = preprocess_input(np.expand_dims(resized, axis=0))

        result = self.curr_model.predict(preprocessed, verbose=0)
        self.socket.send_pyobj({
            'type': 'prediction',
            'result': result
        }, zmq.NOBLOCK)

    def do_layerinfo(self, *args, **kwargs):
        self.socket.send_pyobj(
            {
                'type': 'layer_info',
                'shape': self.curr_model.compute_output_shape(
                    (1, 224, 224, 3)),
                'name': self.parent_model.layers[self.layer].name
            }, zmq.NOBLOCK)

    def do_summary(self, *args, **kwargs):
        if not self.viewable_layers:
            self.socket.send_pyobj(
                {
                    'type': 'summary',
                    'result':
                    [layer.name for layer in self.parent_model.layers]
                }, zmq.NOBLOCK)
        else:
            self.socket.send_pyobj(
                {
                    'type': 'summary',
                    'result': [layer.name for layer in self.viewable_layers]
                }, zmq.NOBLOCK)

    def handle_message(self, message):
        if message['type'] == "change_layer":
            self.change_layer(**message)
        if message['type'] == 'predict':
            self.do_predict(**message)
        if message['type'] == 'layer_info':
            self.do_layerinfo(**message)
        if message['type'] == 'summary':
            self.do_summary(**message)

    def run(self):
        self.running = True
        while self.running:
            message = self.socket.recv_pyobj()
            self.handle_message(message)
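
A matching client is not shown in the example; the sketch below is a hypothetical counterpart that authenticates with PLAIN credentials (the username and password must appear in the logins dict passed to the server) and requests the layer summary over the PAIR socket.

import zmq

ctx = zmq.Context.instance()
sock = ctx.socket(zmq.PAIR)
sock.plain_username = b"user"      # must match an entry in `logins`
sock.plain_password = b"secret"
sock.connect("tcp://127.0.0.1:5560")

sock.send_pyobj({"type": "summary"})
reply = sock.recv_pyobj()          # {'type': 'summary', 'result': [...layer names...]}
print(reply)
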
Example #40
File: mispzmq.py  Project: tomking2/MISP
class MispZmq:
    message_count = 0
    publish_count = 0

    monitor_thread = None
    auth = None
    socket = None
    pidfile = None

    r: redis.StrictRedis
    namespace: str

    def __init__(self):
        self._logger = logging.getLogger()

        self.tmp_location = Path(__file__).parent.parent / "tmp"
        self.pidfile = self.tmp_location / "mispzmq.pid"
        if self.pidfile.exists():
            with open(self.pidfile.as_posix()) as f:
                pid = f.read()
            if check_pid(pid):
                raise Exception(
                    "mispzmq already running on PID {}".format(pid))
            else:
                # Cleanup
                self.pidfile.unlink()
        if (self.tmp_location / "mispzmq_settings.json").exists():
            self._setup()
        else:
            raise Exception("The settings file is missing.")

    def _setup(self):
        with open((self.tmp_location /
                   "mispzmq_settings.json").as_posix()) as settings_file:
            self.settings = json.load(settings_file)
        self.namespace = self.settings["redis_namespace"]
        self.r = redis.StrictRedis(host=self.settings["redis_host"],
                                   db=self.settings["redis_database"],
                                   password=self.settings["redis_password"],
                                   port=self.settings["redis_port"],
                                   decode_responses=True)
        self.timestamp_settings = time.time()
        self._logger.debug("Connected to Redis {}:{}/{}".format(
            self.settings["redis_host"], self.settings["redis_port"],
            self.settings["redis_database"]))

    def _setup_zmq(self):
        context = zmq.Context()

        if "username" in self.settings and self.settings["username"]:
            if "password" not in self.settings or not self.settings["password"]:
                raise Exception(
                    "When username is set, password cannot be empty.")

            self.auth = ThreadAuthenticator(context)
            self.auth.start()
            self.auth.configure_plain(domain="*",
                                      passwords={
                                          self.settings["username"]:
                                          self.settings["password"]
                                      })
        else:
            if self.auth:
                self.auth.stop()
            self.auth = None

        self.socket = context.socket(zmq.PUB)
        if self.settings["username"]:
            self.socket.plain_server = True  # must come before bind
        self.socket.bind("tcp://{}:{}".format(self.settings["host"],
                                              self.settings["port"]))
        self._logger.debug("ZMQ listening on tcp://{}:{}".format(
            self.settings["host"], self.settings["port"]))

        if self._logger.isEnabledFor(logging.DEBUG):
            monitor = self.socket.get_monitor_socket()
            self.monitor_thread = threading.Thread(target=event_monitor,
                                                   args=(monitor,
                                                         self._logger))
            self.monitor_thread.start()
        else:
            if self.monitor_thread:
                self.socket.disable_monitor()
            self.monitor_thread = None

    def _handle_command(self, command):
        if command == "kill":
            self._logger.info("Kill command received, shutting down.")
            self.clean()
            sys.exit()

        elif command == "reload":
            self._logger.info(
                "Reload command received, reloading settings from file.")
            self._setup()
            self._setup_zmq()

        elif command == "status":
            self._logger.info(
                "Status command received, responding with latest stats.")
            self.r.delete("{}:status".format(self.namespace))
            self.r.lpush(
                "{}:status".format(self.namespace),
                json.dumps({
                    "timestamp": time.time(),
                    "timestampSettings": self.timestamp_settings,
                    "publishCount": self.publish_count,
                    "messageCount": self.message_count
                }))
        else:
            self._logger.warning(
                "Received invalid command '{}'.".format(command))

    def _create_pid_file(self):
        with open(self.pidfile.as_posix(), "w") as f:
            f.write(str(os.getpid()))

    def _pub_message(self, topic, data):
        self.socket.send_string("{} {}".format(topic, data))

    def clean(self):
        if self.monitor_thread:
            self.socket.disable_monitor()
        if self.auth:
            self.auth.stop()
        if self.socket:
            self.socket.close()
        if self.pidfile:
            self.pidfile.unlink()

    def main(self):
        self._create_pid_file()
        self._setup_zmq()
        time.sleep(1)

        status_array = [
            "And when you're dead I will be still alive.",
            "And believe me I am still alive.",
            "I'm doing science and I'm still alive.",
            "I feel FANTASTIC and I'm still alive.",
            "While you're dying I'll be still alive.",
        ]
        topics = [
            "misp_json", "misp_json_event", "misp_json_attribute",
            "misp_json_sighting", "misp_json_organisation", "misp_json_user",
            "misp_json_conversation", "misp_json_object",
            "misp_json_object_reference", "misp_json_audit", "misp_json_tag",
            "misp_json_warninglist"
        ]

        lists = ["{}:command".format(self.namespace)]
        for topic in topics:
            lists.append("{}:data:{}".format(self.namespace, topic))

        while True:
            data = self.r.blpop(lists, timeout=10)

            if data is None:
                # redis timeout expired
                current_time = int(time.time())
                time_delta = current_time - int(self.timestamp_settings)
                status_entry = int(time_delta / 10 % 5)
                status_message = {
                    "status": status_array[status_entry],
                    "uptime": current_time - int(self.timestamp_settings)
                }
                self._pub_message("misp_json_self", json.dumps(status_message))
                self._logger.debug(
                    "No message received for 10 seconds, sending ZMQ status message."
                )
            else:
                key, value = data
                key = key.replace("{}:".format(self.namespace), "")
                if key == "command":
                    self._handle_command(value)
                elif key.startswith("data:"):
                    topic = key.split(":")[1]
                    self._logger.debug(
                        "Received data for topic '{}', sending to ZMQ.".format(
                            topic))
                    self._pub_message(topic, value)
                    self.message_count += 1
                    if topic == "misp_json":
                        self.publish_count += 1
                else:
                    self._logger.warning(
                        "Received invalid message '{}'.".format(key))
Example #41
class Command(LAVADaemonCommand):
    """
    worker_host is the hostname of the worker this field is set by the admin
    and could therefore be empty in a misconfigured instance.
    """
    logger = None
    help = "LAVA dispatcher master"
    default_logfile = "/var/log/lava-server/lava-master.log"

    def __init__(self, *args, **options):
        super(Command, self).__init__(*args, **options)
        self.auth = None
        self.controler = None
        self.event_socket = None
        self.poller = None
        self.pipe_r = None
        self.inotify_fd = None
        # List of logs
        # List of known dispatchers. At startup do not load this from the
        # database. This will help to know if the slave has restarted or not.
        self.dispatchers = {"lava-logs": SlaveDispatcher("lava-logs", online=False)}
        self.events = {"canceling": set()}

    def add_arguments(self, parser):
        super(Command, self).add_arguments(parser)
        # Important: ensure share/env.yaml is put into /etc/ by setup.py in packaging.
        config = parser.add_argument_group("dispatcher config")

        config.add_argument('--env',
                            default="/etc/lava-server/env.yaml",
                            help="Environment variables for the dispatcher processes. "
                                 "Default: /etc/lava-server/env.yaml")
        config.add_argument('--env-dut',
                            default="/etc/lava-server/env.dut.yaml",
                            help="Environment variables for device under test. "
                                 "Default: /etc/lava-server/env.dut.yaml")
        config.add_argument('--dispatchers-config',
                            default="/etc/lava-server/dispatcher.d",
                            help="Directory that might contain dispatcher specific configuration")

        net = parser.add_argument_group("network")
        net.add_argument('--master-socket',
                         default='tcp://*:5556',
                         help="Socket for master-slave communication. Default: tcp://*:5556")
        net.add_argument('--event-url', default="tcp://localhost:5500",
                         help="URL of the publisher")
        net.add_argument('--ipv6', default=False, action='store_true',
                         help="Enable IPv6 on the listening sockets")
        net.add_argument('--encrypt', default=False, action='store_true',
                         help="Encrypt messages")
        net.add_argument('--master-cert',
                         default='/etc/lava-dispatcher/certificates.d/master.key_secret',
                         help="Certificate for the master socket")
        net.add_argument('--slaves-certs',
                         default='/etc/lava-dispatcher/certificates.d',
                         help="Directory for slaves certificates")

    def send_status(self, hostname):
        """
        The master crashed, send a STATUS message to get the current state of jobs
        """
        jobs = TestJob.objects.filter(actual_device__worker_host__hostname=hostname,
                                      state=TestJob.STATE_RUNNING)
        for job in jobs:
            self.logger.info("[%d] STATUS => %s (%s)", job.id, hostname,
                             job.actual_device.hostname)
            send_multipart_u(self.controler, [hostname, 'STATUS', str(job.id)])

    def dispatcher_alive(self, hostname):
        if hostname not in self.dispatchers:
            # The server crashed: send a STATUS message
            self.logger.warning("Unknown dispatcher <%s> (server crashed)", hostname)
            self.dispatchers[hostname] = SlaveDispatcher(hostname)
            self.send_status(hostname)

        # Mark the dispatcher as alive
        self.dispatchers[hostname].alive()

    def controler_socket(self):
        try:
            # We need here to use the zmq.NOBLOCK flag, otherwise we could block
            # the whole main loop where this function is called.
            msg = self.controler.recv_multipart(zmq.NOBLOCK)
        except zmq.error.Again:
            return False
        # This is way too verbose for production and should only be activated
        # by (and for) developers
        # self.logger.debug("[CC] Receiving: %s", msg)

        # 1: the hostname (see ZMQ documentation)
        hostname = u(msg[0])
        # 2: the action
        action = u(msg[1])

        # Check that lava-logs only send PINGs
        if hostname == "lava-logs" and action != "PING":
            self.logger.error("%s => %s Invalid action from log daemon",
                              hostname, action)
            return True

        # Handle the actions
        if action == 'HELLO' or action == 'HELLO_RETRY':
            self._handle_hello(hostname, action, msg)
        elif action == 'PING':
            self._handle_ping(hostname, action, msg)
        elif action == 'END':
            self._handle_end(hostname, action, msg)
        elif action == 'START_OK':
            self._handle_start_ok(hostname, action, msg)
        else:
            self.logger.error("<%s> sent unknown action=%s, args=(%s)",
                              hostname, action, msg[1:])
        return True

    def read_event_socket(self):
        try:
            msg = self.event_socket.recv_multipart(zmq.NOBLOCK)
        except zmq.error.Again:
            return False

        try:
            (topic, _, dt, username, data) = (u(m) for m in msg)
        except ValueError:
            self.logger.error("Invalid event: %s", msg)
            return True

        if topic.endswith(".testjob"):
            try:
                data = simplejson.loads(data)
                if data["state"] == "Canceling":
                    self.events["canceling"].add(int(data["job"]))
            except ValueError:
                self.logger.error("Invalid event data: %s", msg)
        return True

    def _handle_end(self, hostname, action, msg):  # pylint: disable=unused-argument
        try:
            job_id = int(msg[2])
            error_msg = msg[3]
            compressed_description = msg[4]
        except (IndexError, ValueError):
            self.logger.error("Invalid message from <%s> '%s'", hostname, msg)
            return

        try:
            job = TestJob.objects.get(id=job_id)
        except TestJob.DoesNotExist:
            self.logger.error("[%d] Unknown job", job_id)
            # ACK even if the job is unknown to let the dispatcher
            # forget about it
            send_multipart_u(self.controler, [hostname, 'END_OK', str(job_id)])
            return

        filename = os.path.join(job.output_dir, 'description.yaml')
        # If description.yaml already exists: an END was already received
        if os.path.exists(filename):
            self.logger.info("[%d] %s => END (duplicated), skipping", job_id, hostname)
        else:
            if compressed_description:
                self.logger.info("[%d] %s => END", job_id, hostname)
            else:
                self.logger.info("[%d] %s => END (lava-run crashed, mark job as INCOMPLETE)",
                                 job_id, hostname)
                with transaction.atomic():
                    # TODO: find a way to lock actual_device
                    job = TestJob.objects.select_for_update() \
                                         .get(id=job_id)

                    job.go_state_finished(TestJob.HEALTH_INCOMPLETE)
                    if error_msg:
                        self.logger.error("[%d] Error: %s", job_id, error_msg)
                        job.failure_comment = error_msg
                    job.save()

            # Create description.yaml even if it's empty
            # Allows to know when END messages are duplicated
            try:
                # Create the directory if it was not already created
                mkdir(os.path.dirname(filename))
                # TODO: check that compressed_description is not ""
                description = lzma.decompress(compressed_description)
                with open(filename, 'w') as f_description:
                    f_description.write(description.decode("utf-8"))
                if description:
                    parse_job_description(job)
            except (IOError, lzma.LZMAError) as exc:
                self.logger.error("[%d] Unable to dump 'description.yaml'",
                                  job_id)
                self.logger.exception("[%d] %s", job_id, exc)

        # ACK the job and mark the dispatcher as alive
        send_multipart_u(self.controler, [hostname, 'END_OK', str(job_id)])
        self.dispatcher_alive(hostname)

    def _handle_hello(self, hostname, action, msg):
        # Check the protocol version
        try:
            slave_version = int(msg[2])
        except (IndexError, ValueError):
            self.logger.error("Invalid message from <%s> '%s'", hostname, msg)
            return

        self.logger.info("%s => %s", hostname, action)
        if slave_version != PROTOCOL_VERSION:
            self.logger.error("<%s> using protocol v%d while master is using v%d",
                              hostname, slave_version, PROTOCOL_VERSION)
            return

        send_multipart_u(self.controler, [hostname, 'HELLO_OK'])
        # If the dispatcher is known and sent a HELLO, it means that
        # the slave has restarted
        if hostname in self.dispatchers:
            if action == 'HELLO':
                self.logger.warning("Dispatcher <%s> has RESTARTED",
                                    hostname)
            else:
                # Assume the HELLO command was received, and the
                # action succeeded.
                self.logger.warning("Dispatcher <%s> was not confirmed",
                                    hostname)
        else:
            # No dispatcher, treat HELLO and HELLO_RETRY as a normal HELLO
            # message.
            self.logger.warning("New dispatcher <%s>", hostname)
            self.dispatchers[hostname] = SlaveDispatcher(hostname)

        # Mark the dispatcher as alive
        self.dispatcher_alive(hostname)

    def _handle_ping(self, hostname, action, msg):  # pylint: disable=unused-argument
        self.logger.debug("%s => PING(%d)", hostname, PING_INTERVAL)
        # Send back a signal
        send_multipart_u(self.controler, [hostname, 'PONG', str(PING_INTERVAL)])
        self.dispatcher_alive(hostname)

    def _handle_start_ok(self, hostname, action, msg):  # pylint: disable=unused-argument
        try:
            job_id = int(msg[2])
        except (IndexError, ValueError):
            self.logger.error("Invalid message from <%s> '%s'", hostname, msg)
            return
        self.logger.info("[%d] %s => START_OK", job_id, hostname)
        try:
            with transaction.atomic():
                # TODO: find a way to lock actual_device
                job = TestJob.objects.select_for_update() \
                                     .get(id=job_id)
                job.go_state_running()
                job.save()
        except TestJob.DoesNotExist:
            self.logger.error("[%d] Unknown job", job_id)
        else:
            self.dispatcher_alive(hostname)

    def export_definition(self, job):  # pylint: disable=no-self-use
        job_def = yaml.load(job.definition)
        job_def['compatibility'] = job.pipeline_compatibility

        # no need for the dispatcher to retain comments
        return yaml.dump(job_def)

    def save_job_config(self, job, worker, device_cfg, options):
        output_dir = job.output_dir
        mkdir(output_dir)
        with open(os.path.join(output_dir, "job.yaml"), "w") as f_out:
            f_out.write(self.export_definition(job))
        with contextlib.suppress(IOError):
            shutil.copy(options["env"], os.path.join(output_dir, "env.yaml"))
        with contextlib.suppress(IOError):
            shutil.copy(options["env_dut"], os.path.join(output_dir, "env.dut.yaml"))
        with contextlib.suppress(IOError):
            shutil.copy(os.path.join(options["dispatchers_config"], "%s.yaml" % worker.hostname),
                        os.path.join(output_dir, "dispatcher.yaml"))
        with open(os.path.join(output_dir, "device.yaml"), "w") as f_out:
            yaml.dump(device_cfg, f_out)

    def start_job(self, job, options):
        # Load job definition to get the variables for template
        # rendering
        job_def = yaml.load(job.definition)
        job_ctx = job_def.get('context', {})

        device = job.actual_device
        worker = device.worker_host

        # Load configurations
        env_str = load_optional_yaml_file(options['env'])
        env_dut_str = load_optional_yaml_file(options['env_dut'])
        device_cfg = device.load_configuration(job_ctx)
        dispatcher_cfg_file = os.path.join(options['dispatchers_config'],
                                           "%s.yaml" % worker.hostname)
        dispatcher_cfg = load_optional_yaml_file(dispatcher_cfg_file)

        self.save_job_config(job, worker, device_cfg, options)
        self.logger.info("[%d] START => %s (%s)", job.id,
                         worker.hostname, device.hostname)
        send_multipart_u(self.controler,
                         [worker.hostname, 'START', str(job.id),
                          self.export_definition(job),
                          yaml.dump(device_cfg),
                          dispatcher_cfg, env_str, env_dut_str])

        # For multinode jobs, start the dynamic connections
        parent = job
        for sub_job in job.sub_jobs_list:
            if sub_job == parent or not sub_job.dynamic_connection:
                continue

            # inherit only enough configuration for dynamic_connection operation
            self.logger.info("[%d] Trimming dynamic connection device configuration.", sub_job.id)
            min_device_cfg = parent.actual_device.minimise_configuration(device_cfg)

            self.save_job_config(sub_job, worker, min_device_cfg, options)
            self.logger.info("[%d] START => %s (connection)",
                             sub_job.id, worker.hostname)
            send_multipart_u(self.controler,
                             [worker.hostname, 'START',
                              str(sub_job.id),
                              self.export_definition(sub_job),
                              yaml.dump(min_device_cfg), dispatcher_cfg,
                              env_str, env_dut_str])

    def start_jobs(self, options):
        """
        Loop on all scheduled jobs and send the START message to the slave.
        """
        # make the request atomic
        query = TestJob.objects.select_for_update()
        # Only select test job that are ready
        query = query.filter(state=TestJob.STATE_SCHEDULED)
        # Only start jobs on online workers
        query = query.filter(actual_device__worker_host__state=Worker.STATE_ONLINE)
        # exclude test job without a device: they are special test jobs like
        # dynamic connection.
        query = query.exclude(actual_device=None)
        # TODO: find a way to lock actual_device

        # Loop on all jobs
        for job in query:
            msg = None
            try:
                self.start_job(job, options)
            except jinja2.TemplateNotFound as exc:
                self.logger.error("[%d] Template not found: '%s'",
                                  job.id, exc.message)
                msg = "Template not found: '%s'" % exc.message
            except jinja2.TemplateSyntaxError as exc:
                self.logger.error("[%d] Template syntax error in '%s', line %d: %s",
                                  job.id, exc.name, exc.lineno, exc.message)
                msg = "Template syntax error in '%s', line %d: %s" % (exc.name, exc.lineno, exc.message)
            except IOError as exc:
                self.logger.error("[%d] Unable to read '%s': %s",
                                  job.id, exc.filename, exc.strerror)
                msg = "Cannot open '%s': %s" % (exc.filename, exc.strerror)
            except yaml.YAMLError as exc:
                self.logger.error("[%d] Unable to parse job definition: %s",
                                  job.id, exc)
                msg = "Cannot parse job definition: %s" % exc

            if msg:
                # Add the error as lava.job result
                metadata = {"case": "job",
                            "definition": "lava",
                            "error_type": "Infrastructure",
                            "error_msg": msg,
                            "result": "fail"}
                suite, _ = TestSuite.objects.get_or_create(name="lava", job=job)
                TestCase.objects.create(name="job", suite=suite, result=TestCase.RESULT_FAIL,
                                        metadata=yaml.dump(metadata))
                job.go_state_finished(TestJob.HEALTH_INCOMPLETE, True)
                job.save()

    def cancel_jobs(self, partial=False):
        query = TestJob.objects.filter(state=TestJob.STATE_CANCELING)
        if partial:
            query = query.filter(id__in=list(self.events["canceling"]))

        for job in query:
            worker = job.lookup_worker if job.dynamic_connection else job.actual_device.worker_host
            self.logger.info("[%d] CANCEL => %s", job.id,
                             worker.hostname)
            send_multipart_u(self.controler,
                             [worker.hostname, 'CANCEL', str(job.id)])

    def handle(self, *args, **options):
        # Initialize logging.
        self.setup_logging("lava-master", options["level"],
                           options["log_file"], FORMAT)

        self.logger.info("[INIT] Dropping privileges")
        if not self.drop_privileges(options['user'], options['group']):
            self.logger.error("[INIT] Unable to drop privileges")
            return

        self.logger.info("[INIT] Marking all workers as offline")
        with transaction.atomic():
            for worker in Worker.objects.select_for_update().all():
                worker.go_state_offline()
                worker.save()

        # Create the sockets
        context = zmq.Context()
        self.controler = context.socket(zmq.ROUTER)
        self.event_socket = context.socket(zmq.SUB)

        if options['ipv6']:
            self.logger.info("[INIT] Enabling IPv6")
            self.controler.setsockopt(zmq.IPV6, 1)
            self.event_socket.setsockopt(zmq.IPV6, 1)

        if options['encrypt']:
            self.logger.info("[INIT] Starting encryption")
            try:
                self.auth = ThreadAuthenticator(context)
                self.auth.start()
                self.logger.debug("[INIT] Opening master certificate: %s", options['master_cert'])
                master_public, master_secret = zmq.auth.load_certificate(options['master_cert'])
                self.logger.debug("[INIT] Using slaves certificates from: %s", options['slaves_certs'])
                self.auth.configure_curve(domain='*', location=options['slaves_certs'])
            except IOError as err:
                self.logger.error(err)
                self.auth.stop()
                return
            self.controler.curve_publickey = master_public
            self.controler.curve_secretkey = master_secret
            self.controler.curve_server = True

            self.logger.debug("[INIT] Watching %s", options["slaves_certs"])
            self.inotify_fd = watch_directory(options["slaves_certs"])
            if self.inotify_fd is None:
                self.logger.error("[INIT] Unable to start inotify")

        self.controler.setsockopt(zmq.IDENTITY, b"master")
        # From http://api.zeromq.org/4-2:zmq-setsockopt#toc42
        # "If two clients use the same identity when connecting to a ROUTER
        # [...] the ROUTER socket shall hand-over the connection to the new
        # client and disconnect the existing one."
        self.controler.setsockopt(zmq.ROUTER_HANDOVER, 1)
        self.controler.bind(options['master_socket'])

        self.event_socket.setsockopt(zmq.SUBSCRIBE, b(settings.EVENT_TOPIC))
        self.event_socket.connect(options['event_url'])

        # Poll on the sockets. This allows having a
        # nice timeout along with polling.
        self.poller = zmq.Poller()
        self.poller.register(self.controler, zmq.POLLIN)
        self.poller.register(self.event_socket, zmq.POLLIN)
        if self.inotify_fd is not None:
            self.poller.register(os.fdopen(self.inotify_fd), zmq.POLLIN)

        # Translate signals into zmq messages
        (self.pipe_r, _) = self.setup_zmq_signal_handler()
        self.poller.register(self.pipe_r, zmq.POLLIN)

        self.logger.info("[INIT] LAVA master has started.")
        self.logger.info("[INIT] Using protocol version %d", PROTOCOL_VERSION)

        try:
            self.main_loop(options)
        except BaseException as exc:
            self.logger.error("[CLOSE] Unknown exception raised, leaving!")
            self.logger.exception(exc)
        finally:
            # Drop controler socket: the protocol does handle lost messages
            self.logger.info("[CLOSE] Closing the controler socket and dropping messages")
            self.controler.close(linger=0)
            self.event_socket.close(linger=0)
            if options['encrypt']:
                self.auth.stop()
            context.term()

    def main_loop(self, options):
        last_schedule = last_dispatcher_check = time.time()

        while True:
            try:
                try:
                    # Compute the timeout
                    now = time.time()
                    timeout = min(SCHEDULE_INTERVAL - (now - last_schedule),
                                  PING_INTERVAL - (now - last_dispatcher_check))
                    # If some actions are remaining, decrease the timeout
                    if self.events["canceling"]:
                        timeout = min(timeout, 1)
                    # Wait at least for 1ms
                    timeout = max(timeout * 1000, 1)

                    # Wait for data or a timeout
                    sockets = dict(self.poller.poll(timeout))
                except zmq.error.ZMQError:
                    continue

                if sockets.get(self.pipe_r) == zmq.POLLIN:
                    self.logger.info("[POLL] Received a signal, leaving")
                    break

                # Command socket
                if sockets.get(self.controler) == zmq.POLLIN:
                    while self.controler_socket():  # Unqueue all pending messages
                        pass

                # Events socket
                if sockets.get(self.event_socket) == zmq.POLLIN:
                    while self.read_event_socket():  # Unqueue all pending messages
                        pass
                    # Wait for the next iteration to handle the event.
                    # In fact, the code that generated the event (lava-logs or
                    # lava-server-gunicorn) needs some time to commit the
                    # database transaction.
                    # If we are too fast, the database object won't be
                    # available (or in the right state) yet.
                    continue

                # Inotify socket
                if sockets.get(self.inotify_fd) == zmq.POLLIN:
                    os.read(self.inotify_fd, 4096)
                    self.logger.debug("[AUTH] Reloading certificates from %s",
                                      options['slaves_certs'])
                    self.auth.configure_curve(domain='*', location=options['slaves_certs'])

                # Check dispatchers status
                now = time.time()
                if now - last_dispatcher_check > PING_INTERVAL:
                    for hostname, dispatcher in self.dispatchers.items():
                        if dispatcher.online and now - dispatcher.last_msg > DISPATCHER_TIMEOUT:
                            if hostname == "lava-logs":
                                self.logger.error("[STATE] lava-logs goes OFFLINE")
                            else:
                                self.logger.error("[STATE] Dispatcher <%s> goes OFFLINE", hostname)
                            self.dispatchers[hostname].go_offline()
                    last_dispatcher_check = now

                # Limit accesses to the database. This will also limit the rate of
                # CANCEL and START messages
                if time.time() - last_schedule > SCHEDULE_INTERVAL:
                    if self.dispatchers["lava-logs"].online:
                        schedule(self.logger)

                        # Dispatch scheduled jobs
                        with transaction.atomic():
                            self.start_jobs(options)
                    else:
                        self.logger.warning("lava-logs is offline: can't schedule jobs")

                    # Handle canceling jobs
                    self.cancel_jobs()

                    # Do not count the time taken to schedule jobs
                    last_schedule = time.time()
                else:
                    # Cancel the jobs and remove the jobs from the set
                    if self.events["canceling"]:
                        self.cancel_jobs(partial=True)
                        self.events["canceling"] = set()

            except (OperationalError, InterfaceError):
                self.logger.info("[RESET] database connection reset.")
                # Closing the database connection will force Django to reopen
                # the connection
                connection.close()
                time.sleep(2)
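
For the encrypted case, a slave needs the counterpart CURVE setup. The sketch below is illustrative only (paths, hostname and message content are assumptions, not LAVA's actual slave implementation): the slave loads its own keypair, pins the master's public key, and connects a DEALER socket to the master's ROUTER.

import zmq
import zmq.auth

ctx = zmq.Context.instance()
sock = ctx.socket(zmq.DEALER)
sock.setsockopt(zmq.IDENTITY, b"worker01")

# The slave needs its own keypair plus the master's public key.
slave_public, slave_secret = zmq.auth.load_certificate(
    "/etc/lava-dispatcher/certificates.d/slave.key_secret")
master_public, _ = zmq.auth.load_certificate(
    "/etc/lava-dispatcher/certificates.d/master.key")
sock.curve_publickey = slave_public
sock.curve_secretkey = slave_secret
sock.curve_serverkey = master_public   # must be set before connect

sock.connect("tcp://master.example.org:5556")
sock.send_multipart([b"PING"])         # illustrative message only
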
Example #42
class CombaZMQAdapter(threading.Thread, CombaBase):
    
    def __init__(self, port):

        self.port = str(port)
        threading.Thread.__init__ (self)
        self.shutdown_event = Event()
        self.context = zmq.Context().instance()
        self.authserver = ThreadAuthenticator(self.context)
        self.loadConfig()
        self.start()

    #------------------------------------------------------------------------------------------#
    def run(self):
        """
        run runs on function start
        """

        self.startAuthserver()
        self.data = ''
        self.socket = self.context.socket(zmq.REP)
        self.socket.plain_server = True
        self.socket.bind("tcp://*:"+self.port)
        self.shutdown_event.clear()
        self.controller = CombaController(self, self.lqs_socket, self.lqs_recorder_socket)
        self.controller.messenger.setMailAddresses(self.get('frommail'), self.get('adminmail'))
        self.can_send = False
        # Process tasks forever
        while not self.shutdown_event.is_set():
            self.data = self.socket.recv()
            self.can_send = True
            data = self.data.split(' ')
            command = str(data.pop(0)) 
            params = "()" if len(data) < 1 else  "('" + "','".join(data) + "')" 
                     
            try: 
                exec"a=self.controller." + command + params  
            
            except SyntaxError:                
                self.controller.message('Warning: Syntax Error')

            except AttributeError:
                print "Warning: Method " + command + " does not exist"
                self.controller.message('Warning: Method ' + command + ' does not exist')
            except TypeError:
                print "Warning: Wrong number of params"
                self.controller.message('Warning: Wrong number of params')
            except:
                print "Warning: Unknown Error"
                self.controller.message('Warning: Unknown Error')

        return

    #------------------------------------------------------------------------------------------#
    def halt(self):
        """
        Stop the server
        """
        if self.shutdown_event.is_set():
            return
        try:
            del self.controller
        except:
            pass
        self.shutdown_event.set()
        result = 'failed'
        try:
            result = self.socket.unbind("tcp://*:"+self.port)
        except:
            pass
        #self.socket.close()

    #------------------------------------------------------------------------------------------#
    def reload(self):
        """
        Stop, reload the config and start again
        """
        if self.shutdown_event.is_set():
            return
        self.loadConfig()
        self.halt()
        time.sleep(3)
        self.run()

    #------------------------------------------------------------------------------------------#
    def send(self,message):
        """
        Send a message to the client
        :param message: string
        """
        if self.can_send:
            self.socket.send(message)
            self.can_send = False

    #------------------------------------------------------------------------------------------#
    def startAuthserver(self):
        """
        Start the zmq authentication server
        """
        # stop auth server if running
        if self.authserver.is_alive():
            self.authserver.stop()
        if self.securitylevel > 0:
            # Start the authentication server.

            self.authserver.start()

            # At security level 2, also require a username and password
            if self.securitylevel > 1:
                try:

                    addresses = CombaWhitelist().getList()
                    for address in addresses:
                        self.authserver.allow(address)

                except:
                    pass

            # Instruct authenticator to handle PLAIN requests
            self.authserver.configure_plain(domain='*', passwords=self.getAccounts())

    #------------------------------------------------------------------------------------------#
    def getAccounts(self):
        """
        Get accounts from redis db
        :return: list - a list of accounts
        """
        accounts = CombaUser().getLogins()
        db = redis.Redis()

        internaccount = db.get('internAccess')
        if not internaccount:
            user = ''.join(random.sample(string.lowercase,10))
            password = ''.join(random.sample(string.lowercase+string.uppercase+string.digits,22))
            db.set('internAccess', user + ':' + password)
            intern = [user, password]
        else:
            intern =  internaccount.split(':')

        accounts[intern[0]] = intern[1]

        return accounts
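
A client for this REP socket would look roughly like the sketch below (hypothetical account, port and command; the credentials must exist in the accounts returned by getAccounts()).

import zmq

ctx = zmq.Context.instance()
req = ctx.socket(zmq.REQ)
req.plain_username = b"admin"             # any login known to getAccounts()
req.plain_password = b"secret"
req.connect("tcp://127.0.0.1:9099")       # the port passed to CombaZMQAdapter(port)

req.send(b"get_channel_queue channel1")   # "<command> <params...>" as parsed by run()
print(req.recv())
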
Example #43
class ZmqReceiver(object):
    def __init__(self, zmq_rep_bind_address=None, zmq_sub_connect_addresses=None, recreate_sockets_on_timeout_of_sec=600, username=None, password=None):
        self.context = zmq.Context()
        self.auth = None
        self.last_received_message = None
        self.is_running = False
        self.thread = None
        self.zmq_rep_bind_address = zmq_rep_bind_address
        self.zmq_sub_connect_addresses = zmq_sub_connect_addresses
        self.poller = zmq.Poller()
        self.sub_sockets = []
        self.rep_socket = None
        if username is not None and password is not None:
            # Start an authenticator for this context.
            # Does not work on PUB/SUB as far as I know (probably because the more
            # secure solutions require two-way communication as well)
            self.auth = ThreadAuthenticator(self.context)
            self.auth.start()
            # Instruct authenticator to handle PLAIN requests
            self.auth.configure_plain(domain='*', passwords={username: password})

        if self.zmq_sub_connect_addresses:
            for address in self.zmq_sub_connect_addresses:
                self.sub_sockets.append(SubSocket(self.context, self.poller, address, recreate_sockets_on_timeout_of_sec))
        if zmq_rep_bind_address:
            self.rep_socket = RepSocket(self.context, self.poller, zmq_rep_bind_address, self.auth)

    # May take up to 60 seconds to actually stop since poller has timeout of 60 seconds
    def stop(self):
        self.is_running = False
        logger.info("Closing pub and sub sockets...")
        if self.auth is not None:
            self.auth.stop()

    def run(self):
        self.is_running = True

        while self.is_running:
            socks = dict(self.poller.poll(1000))
            logger.debug("Poll cycle over. checking sockets")
            if self.rep_socket:
                incoming_message = self.rep_socket.recv_string(socks)
                if incoming_message is not None:
                    self.last_received_message = incoming_message
                    try:
                        logger.debug("Got info from REP socket")
                        response_message = self.handle_incoming_message(incoming_message)
                        self.rep_socket.send(response_message)
                    except Exception as e:
                        logger.error(e)
            for sub_socket in self.sub_sockets:
                incoming_message = sub_socket.recv_string(socks)
                if incoming_message is not None:
                    if incoming_message != "zmq_sub_heartbeat":
                        self.last_received_message = incoming_message
                    logger.debug("Got info from SUB socket")
                    try:
                        self.handle_incoming_message(incoming_message)
                    except Exception as e:
                        logger.error(e)

        if self.rep_socket:
            self.rep_socket.destroy()
        for sub_socket in self.sub_sockets:
            sub_socket.destroy()

    def create_response_message(self, status_code, status_message, response_message):
        if response_message is not None:
            return json.dumps({"status_code": status_code, "status_message": status_message, "response_message": response_message})
        else:
            return json.dumps({"status_code": status_code, "status_message": status_message})

    def handle_incoming_message(self, message):
        if message != "zmq_sub_heartbeat":
            return self.create_response_message(200, "OK", None)
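
The matching sender is not included here; a minimal sketch of a REQ-side peer, under the assumption that the receiver was constructed with a zmq_rep_bind_address and PLAIN credentials, would be:

import zmq

ctx = zmq.Context()
req = ctx.socket(zmq.REQ)
req.plain_username = b"user"       # same pair given to ZmqReceiver(username=..., password=...)
req.plain_password = b"password"
req.connect("tcp://localhost:5561")

req.send_string("ping")
print(req.recv_string())           # e.g. {"status_code": 200, "status_message": "OK"}
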
Example #44
class TaskQueue:
    """Outgoing task queue from the executor to the Interchange"""
    def __init__(
            self,
            address: str,
            port: int = 55001,
            identity: str = str(uuid.uuid4()),
            zmq_context=None,
            set_hwm=False,
            RCVTIMEO=None,
            SNDTIMEO=None,
            linger=None,
            ironhouse: bool = False,
            keys_dir: str = os.path.abspath(".curve"),
            mode: str = "client",
    ):
        """
        Parameters
        ----------

        address: str
           address to connect

        port: int
           Port to use

        identity : str
           Applies only to clients, where the identity must match the endpoint uuid.
           This will be utf-8 encoded on the wire. A random uuid4 string is set by
           default.

        mode: string
           Either 'client' or 'server'

        keys_dir : string
           Directory from which keys will be loaded for curve.

        ironhouse: Bool
           Only valid for server mode. Setting this flag switches the server to require
           client keys to be available on the server in the keys_dir.
        """
        if zmq_context:
            self.context = zmq_context
        else:
            self.context = zmq.Context()

        self.mode = mode
        self.port = port
        self.ironhouse = ironhouse
        self.keys_dir = keys_dir

        assert self.mode in [
            "client",
            "server",
        ], "Only two modes are supported: client, server"

        if self.mode == "server":
            print("Configuring server")
            self.zmq_socket = self.context.socket(zmq.ROUTER)
            self.zmq_socket.set(zmq.ROUTER_MANDATORY, 1)
            self.zmq_socket.set(zmq.ROUTER_HANDOVER, 1)
            print("Setting up auth-server")
            self.setup_server_auth()
        elif self.mode == "client":
            self.zmq_socket = self.context.socket(zmq.DEALER)
            self.setup_client_auth()
            self.zmq_socket.setsockopt(zmq.IDENTITY, identity.encode("utf-8"))
        else:
            raise ValueError(
                "TaskQueue must be initialized with mode set to 'server' or 'client'"
            )

        if set_hwm:
            self.zmq_socket.set_hwm(0)
        if RCVTIMEO is not None:
            self.zmq_socket.setsockopt(zmq.RCVTIMEO, RCVTIMEO)
        if SNDTIMEO is not None:
            self.zmq_socket.setsockopt(zmq.SNDTIMEO, SNDTIMEO)
        if linger is not None:
            self.zmq_socket.setsockopt(zmq.LINGER, linger)

        # all zmq setsockopt calls must be done before bind/connect is called
        if self.mode == "server":
            self.zmq_socket.bind(f"tcp://*:{port}")
        elif self.mode == "client":
            self.zmq_socket.connect(f"tcp://{address}:{port}")

        self.poller = zmq.Poller()
        self.poller.register(self.zmq_socket)
        os.makedirs(self.keys_dir, exist_ok=True)
        log.debug(f"Initializing Taskqueue:{self.mode} on port:{self.port}")

    def zmq_context(self):
        return self.context

    def add_client_key(self, endpoint_id, client_key):
        log.info("Adding client key")
        if self.ironhouse:
            # Use the ironhouse ZMQ pattern: http://hintjens.com/blog:49#toc6
            with open(os.path.join(self.keys_dir, f"{endpoint_id}.key"),
                      "w") as f:
                f.write(client_key)
            try:
                self.auth.configure_curve(domain="*", location=self.keys_dir)
            except Exception:
                log.exception("Failed to load keys from {self.keys_dir}")
        return

    def setup_server_auth(self):
        # Start an authenticator for this context.
        self.auth = ThreadAuthenticator(self.context)
        self.auth.start()
        self.auth.allow("127.0.0.1")
        # Tell the authenticator how to handle CURVE requests

        if not self.ironhouse:
            # Use the stonehouse ZMQ pattern: http://hintjens.com/blog:49#toc5
            self.auth.configure_curve(domain="*",
                                      location=zmq.auth.CURVE_ALLOW_ANY)

        server_secret_file = os.path.join(self.keys_dir, "server.key_secret")
        server_public, server_secret = zmq.auth.load_certificate(
            server_secret_file)
        self.zmq_socket.curve_secretkey = server_secret
        self.zmq_socket.curve_publickey = server_public
        self.zmq_socket.curve_server = True  # must come before bind

    def setup_client_auth(self):
        # We need two certificates, one for the client and one for
        # the server. The client must know the server's public key
        # to make a CURVE connection.
        client_secret_file = os.path.join(self.keys_dir, "endpoint.key_secret")
        client_public, client_secret = zmq.auth.load_certificate(
            client_secret_file)
        self.zmq_socket.curve_secretkey = client_secret
        self.zmq_socket.curve_publickey = client_public

        # The client must know the server's public key to make a CURVE connection.
        server_public_file = os.path.join(self.keys_dir, "server.key")
        server_public, _ = zmq.auth.load_certificate(server_public_file)
        self.zmq_socket.curve_serverkey = server_public

    def get(self, block=True, timeout=1000):
        """
        Parameters
        ----------

        block : Bool
            Blocks until there's a message, Default is True
        timeout : int
            Milliseconds to wait.
        """
        # timeout is in milliseconds
        if block is True:
            return self.zmq_socket.recv_multipart()

        socks = dict(self.poller.poll(timeout=timeout))
        if self.zmq_socket in socks and socks[self.zmq_socket] == zmq.POLLIN:
            message = self.zmq_socket.recv_multipart()
            return message
        else:
            raise zmq.Again

    def register_client(self, message):
        return self.zmq_socket.send_multipart([message])

    def put(self, dest, message, max_timeout=1000):
        """This function needs to be fast at the same time aware of the possibility of
        ZMQ pipes overflowing.

        The timeout increases slowly if contention is detected on ZMQ pipes.
        We could set copy=False and get slightly better latency but this results
        in ZMQ sockets reaching a broken state once there are ~10k tasks in flight.
        This issue can be magnified if the serialized buffer itself is larger.

        Parameters
        ----------

        dest : zmq_identity of the destination endpoint, must be a byte string

        message : py object
             Python object to send

        max_timeout : int
             Max timeout in milliseconds that we will wait for before raising an
             exception

        Raises
        ------

        zmq.EAGAIN if the send failed.
        zmq.error.ZMQError: Host unreachable (if client disconnects?)

        """
        if self.mode == "client":
            return self.zmq_socket.send_multipart([message])
        else:
            return self.zmq_socket.send_multipart([dest, message])

    def close(self):
        self.zmq_socket.close()
        self.context.term()
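
The keys_dir is expected to contain server.key_secret, server.key, endpoint.key_secret and endpoint.key. A one-time generation sketch using pyzmq's certificate helpers (filenames follow what setup_server_auth() and setup_client_auth() load) could look like:

import os
import zmq.auth

keys_dir = os.path.abspath(".curve")
os.makedirs(keys_dir, exist_ok=True)

# Creates server.key / server.key_secret and endpoint.key / endpoint.key_secret
zmq.auth.create_certificates(keys_dir, "server")
zmq.auth.create_certificates(keys_dir, "endpoint")
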
Example #45
def main():
    """
    Runs SEND either in transmitter or receiver mode

    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "-t",
        "--transmit",
        action="store_true",
        help="Flag indicating that user will be transmitting files"
    )
    parser.add_argument(
        "-r",
        "--receive",
        action="store_true",
        help="Flag indicating that user will be receiving files"
    )
    parser.add_argument(
        "--location",
        help="Location of files to send/receive. Can be a specific file if tx."
    )
    parser.add_argument(
        "--ip",
        help="IP Address to form connection with"
    )
    parser.add_argument(
        "--port",
        nargs='?',
        const=6000,
        default=6000,
        type=int,
        help="Port to form connection with (only needed if using non-default)"
    )
    parser.add_argument(
        "--public_key",
        nargs='?',
        help="Public Key of transmitter in plain-text (only needed if receiver)"
    )
    args=parser.parse_args()

    # Security Authentication Thread
    _generate_security_keys()
    authenticator = ThreadAuthenticator(manager.ctx)
    authenticator.start()
    whitelist = [
        "127.0.0.1",
        args.ip
    ]
    authenticator.allow(*whitelist)
    authenticator.configure_curve(domain="*", location=PUBKEYS)

    try:
        if args.transmit:
            thread = manager.publish_folder(
                args.port,
                args.location
            )

        elif args.receive:
            thread = manager.subscribe_folder(
                args.ip,
                args.port,
                args.location,
                args.public_key
            )
        else:
            raise ValueError(f"User did not specify transmit/receive")

    except (OSError, ValueError):
        raise

    except KeyboardInterrupt:
        pass

    finally:
        # Keep things rolling until the transfer is done or the thread dies from
        # timing out
        while thread.is_alive():
            thread.join(timeout=1)
        # Clean up and close everything out
        authenticator.stop()
        # Use destroy versus term: https://github.com/zeromq/pyzmq/issues/991
        manager.ctx.destroy()
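# --- Hedged sketch (not shown in this example) --------------------------------
# _generate_security_keys() and the PUBKEYS constant used above are elided from
# the snippet. With pyzmq, CURVE key pairs are normally created with
# zmq.auth.create_certificates(); the directory layout below (PUBKEYS/SECKEYS)
# and the function name are placeholders, not the example's real configuration.
import os
import shutil
import zmq.auth

PUBKEYS = os.path.expanduser("~/.curve/public_keys")    # placeholder path
SECKEYS = os.path.expanduser("~/.curve/private_keys")   # placeholder path

def generate_security_keys(name="send"):
    """Create a CURVE key pair and file the public half where
    configure_curve(domain='*', location=PUBKEYS) can find it."""
    os.makedirs(PUBKEYS, exist_ok=True)
    os.makedirs(SECKEYS, exist_ok=True)
    public_file, secret_file = zmq.auth.create_certificates(SECKEYS, name)
    # Only the *.key (public) certificate needs to be visible to the authenticator
    target = os.path.join(PUBKEYS, os.path.basename(public_file))
    shutil.move(public_file, target)
    return target, secret_file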
Example #46
class Command(LAVADaemonCommand):
    """
    worker_host is the hostname of the worker. This field is set by the admin
    and could therefore be empty in a misconfigured instance.
    """
    logger = None
    help = "LAVA dispatcher master"
    default_logfile = "/var/log/lava-server/lava-master.log"

    def __init__(self, *args, **options):
        super().__init__(*args, **options)
        self.auth = None
        self.controler = None
        self.event_socket = None
        self.poller = None
        self.pipe_r = None
        self.inotify_fd = None
        # List of logs
        # List of known dispatchers. At startup, do not load this from the
        # database. This will help to know whether the slave has restarted or not.
        self.dispatchers = {
            "lava-logs": SlaveDispatcher("lava-logs", online=False)
        }
        self.events = {"canceling": set(), "available_dt": set()}

    def add_arguments(self, parser):
        super().add_arguments(parser)
        net = parser.add_argument_group("network")
        net.add_argument(
            '--master-socket',
            default='tcp://*:5556',
            help="Socket for master-slave communication. Default: tcp://*:5556"
        )
        net.add_argument('--event-url',
                         default="tcp://localhost:5500",
                         help="URL of the publisher")
        net.add_argument('--ipv6',
                         default=False,
                         action='store_true',
                         help="Enable IPv6 on the listening sockets")
        net.add_argument('--encrypt',
                         default=False,
                         action='store_true',
                         help="Encrypt messages")
        net.add_argument(
            '--master-cert',
            default='/etc/lava-dispatcher/certificates.d/master.key_secret',
            help="Certificate for the master socket")
        net.add_argument('--slaves-certs',
                         default='/etc/lava-dispatcher/certificates.d',
                         help="Directory for slaves certificates")

    def send_status(self, hostname):
        """
        The master crashed; send a STATUS message to get the current state of jobs
        """
        jobs = TestJob.objects.filter(
            actual_device__worker_host__hostname=hostname,
            state=TestJob.STATE_RUNNING)
        for job in jobs:
            self.logger.info("[%d] STATUS => %s (%s)", job.id, hostname,
                             job.actual_device.hostname)
            send_multipart_u(self.controler, [hostname, 'STATUS', str(job.id)])

    def dispatcher_alive(self, hostname):
        if hostname not in self.dispatchers:
            # The server crashed: send a STATUS message
            self.logger.warning("Unknown dispatcher <%s> (server crashed)",
                                hostname)
            self.dispatchers[hostname] = SlaveDispatcher(hostname)
            self.send_status(hostname)

        # Mark the dispatcher as alive
        self.dispatchers[hostname].alive()

    def controler_socket(self):
        try:
            # We need to use the zmq.NOBLOCK flag here, otherwise we could block
            # the whole main loop where this function is called.
            msg = self.controler.recv_multipart(zmq.NOBLOCK)
        except zmq.error.Again:
            return False
        # This is way too verbose for production and should only be activated
        # by (and for) developers
        # self.logger.debug("[CC] Receiving: %s", msg)

        # 1: the hostname (see ZMQ documentation)
        hostname = u(msg[0])
        # 2: the action
        action = u(msg[1])

        # Check that lava-logs only sends PINGs
        if hostname == "lava-logs" and action != "PING":
            self.logger.error("%s => %s Invalid action from log daemon",
                              hostname, action)
            return True

        # Handle the actions
        if action == 'HELLO' or action == 'HELLO_RETRY':
            self._handle_hello(hostname, action, msg)
        elif action == 'PING':
            self._handle_ping(hostname, action, msg)
        elif action == 'END':
            self._handle_end(hostname, action, msg)
        elif action == 'START_OK':
            self._handle_start_ok(hostname, action, msg)
        else:
            self.logger.error("<%s> sent unknown action=%s, args=(%s)",
                              hostname, action, msg[1:])
        return True

    def read_event_socket(self):
        try:
            msg = self.event_socket.recv_multipart(zmq.NOBLOCK)
        except zmq.error.Again:
            return False

        try:
            (topic, _, dt, username, data) = (u(m) for m in msg)
            data = simplejson.loads(data)
        except ValueError:
            self.logger.error("Invalid event: %s", msg)
            return True

        if topic.endswith(".testjob"):
            if data["state"] == "Canceling":
                self.events["canceling"].add(int(data["job"]))
            elif data["state"] == "Submitted":
                if "device_type" in data:
                    self.events["available_dt"].add(data["device_type"])
        elif topic.endswith(".device"):
            if data["state"] == "Idle" and data["health"] in [
                    "Good", "Unknown", "Looping"
            ]:
                self.events["available_dt"].add(data["device_type"])

        return True

    def _handle_end(self, hostname, action, msg):  # pylint: disable=unused-argument
        try:
            job_id = int(msg[2])
            error_msg = msg[3]
            compressed_description = msg[4]
        except (IndexError, ValueError):
            self.logger.error("Invalid message from <%s> '%s'", hostname, msg)
            return

        try:
            job = TestJob.objects.get(id=job_id)
        except TestJob.DoesNotExist:
            self.logger.error("[%d] Unknown job", job_id)
            # ACK even if the job is unknown to let the dispatcher
            # forget about it
            send_multipart_u(self.controler, [hostname, 'END_OK', str(job_id)])
            return

        filename = os.path.join(job.output_dir, 'description.yaml')
        # If description.yaml already exists: an END was already received
        if os.path.exists(filename):
            self.logger.info("[%d] %s => END (duplicated), skipping", job_id,
                             hostname)
        else:
            if compressed_description:
                self.logger.info("[%d] %s => END", job_id, hostname)
            else:
                self.logger.info(
                    "[%d] %s => END (lava-run crashed, mark job as INCOMPLETE)",
                    job_id, hostname)
                with transaction.atomic():
                    # TODO: find a way to lock actual_device
                    job = TestJob.objects.select_for_update() \
                                         .get(id=job_id)

                    job.go_state_finished(TestJob.HEALTH_INCOMPLETE)
                    if error_msg:
                        self.logger.error("[%d] Error: %s", job_id, error_msg)
                        job.failure_comment = error_msg
                    job.save()

            # Create description.yaml even if it's empty.
            # This allows us to know when END messages are duplicated
            try:
                # Create the directory if it was not already created
                mkdir(os.path.dirname(filename))
                # TODO: check that compressed_description is not ""
                description = lzma.decompress(compressed_description)
                with open(filename, 'w') as f_description:
                    f_description.write(description.decode("utf-8"))
                if description:
                    parse_job_description(job)
            except (OSError, lzma.LZMAError) as exc:
                self.logger.error("[%d] Unable to dump 'description.yaml'",
                                  job_id)
                self.logger.exception("[%d] %s", job_id, exc)

        # ACK the job and mark the dispatcher as alive
        send_multipart_u(self.controler, [hostname, 'END_OK', str(job_id)])
        self.dispatcher_alive(hostname)

    def _handle_hello(self, hostname, action, msg):
        # Check the protocol version
        try:
            slave_version = int(msg[2])
        except (IndexError, ValueError):
            self.logger.error("Invalid message from <%s> '%s'", hostname, msg)
            return

        self.logger.info("%s => %s", hostname, action)
        if slave_version != PROTOCOL_VERSION:
            self.logger.error(
                "<%s> using protocol v%d while master is using v%d", hostname,
                slave_version, PROTOCOL_VERSION)
            return

        send_multipart_u(self.controler, [hostname, 'HELLO_OK'])
        # If the dispatcher is known and sent a HELLO, it means that
        # the slave has restarted
        if hostname in self.dispatchers:
            if action == 'HELLO':
                self.logger.warning("Dispatcher <%s> has RESTARTED", hostname)
            else:
                # Assume the HELLO command was received, and the
                # action succeeded.
                self.logger.warning("Dispatcher <%s> was not confirmed",
                                    hostname)
        else:
            # No dispatcher, treat HELLO and HELLO_RETRY as a normal HELLO
            # message.
            self.logger.warning("New dispatcher <%s>", hostname)
            self.dispatchers[hostname] = SlaveDispatcher(hostname)

        # Mark the dispatcher as alive
        self.dispatcher_alive(hostname)

    def _handle_ping(self, hostname, action, msg):  # pylint: disable=unused-argument
        self.logger.debug("%s => PING(%d)", hostname, PING_INTERVAL)
        # Send back a signal
        send_multipart_u(
            self.controler,
            [hostname, 'PONG', str(PING_INTERVAL)])
        self.dispatcher_alive(hostname)

    def _handle_start_ok(self, hostname, action, msg):  # pylint: disable=unused-argument
        try:
            job_id = int(msg[2])
        except (IndexError, ValueError):
            self.logger.error("Invalid message from <%s> '%s'", hostname, msg)
            return
        self.logger.info("[%d] %s => START_OK", job_id, hostname)
        try:
            with transaction.atomic():
                # TODO: find a way to lock actual_device
                job = TestJob.objects.select_for_update() \
                                     .get(id=job_id)
                job.go_state_running()
                job.save()
        except TestJob.DoesNotExist:
            self.logger.error("[%d] Unknown job", job_id)
        else:
            self.dispatcher_alive(hostname)

    def export_definition(self, job):  # pylint: disable=no-self-use
        job_def = yaml.safe_load(job.definition)
        job_def['compatibility'] = job.pipeline_compatibility

        # no need for the dispatcher to retain comments
        return yaml.dump(job_def)

    def save_job_config(self, job, device_cfg, env_str, env_dut_str,
                        dispatcher_cfg):
        output_dir = job.output_dir
        mkdir(output_dir)
        with open(os.path.join(output_dir, "job.yaml"), "w") as f_out:
            f_out.write(self.export_definition(job))
        with open(os.path.join(output_dir, "device.yaml"), "w") as f_out:
            yaml.dump(device_cfg, f_out)
        if env_str:
            with open(os.path.join(output_dir, "env.yaml"), "w") as f_out:
                f_out.write(env_str)
        if env_dut_str:
            with open(os.path.join(output_dir, "env.dut.yaml"), "w") as f_out:
                f_out.write(env_dut_str)
        if dispatcher_cfg:
            with open(os.path.join(output_dir, "dispatcher.yaml"),
                      "w") as f_out:
                f_out.write(dispatcher_cfg)

    def start_job(self, job):
        # Load job definition to get the variables for template
        # rendering
        job_def = yaml.safe_load(job.definition)
        job_ctx = job_def.get('context', {})

        device = job.actual_device
        worker = device.worker_host

        # TODO: check that device_cfg is not None!
        device_cfg = device.load_configuration(job_ctx)

        # Try to load the dispatcher specific files and then fallback to the
        # default configuration files.
        env_str = load_optional_yaml_file(
            os.path.join(DISPATCHERS_PATH, worker.hostname, "env.yaml"),
            ENV_PATH)
        env_dut_str = load_optional_yaml_file(
            os.path.join(DISPATCHERS_PATH, worker.hostname, "env.dut.yaml"),
            ENV_DUT_PATH)
        dispatcher_cfg = load_optional_yaml_file(
            os.path.join(DISPATCHERS_PATH, worker.hostname, "dispatcher.yaml"),
            os.path.join(DISPATCHERS_PATH, "%s.yaml" % worker.hostname))

        self.save_job_config(job, device_cfg, env_str, env_dut_str,
                             dispatcher_cfg)
        self.logger.info("[%d] START => %s (%s)", job.id, worker.hostname,
                         device.hostname)
        send_multipart_u(self.controler, [
            worker.hostname, 'START',
            str(job.id),
            self.export_definition(job),
            yaml.dump(device_cfg), dispatcher_cfg, env_str, env_dut_str
        ])

        # For multinode jobs, start the dynamic connections
        parent = job
        for sub_job in job.sub_jobs_list:
            if sub_job == parent or not sub_job.dynamic_connection:
                continue

            # inherit only enough configuration for dynamic_connection operation
            self.logger.info(
                "[%d] Trimming dynamic connection device configuration.",
                sub_job.id)
            min_device_cfg = parent.actual_device.minimise_configuration(
                device_cfg)

            self.save_job_config(sub_job, min_device_cfg, env_str, env_dut_str,
                                 dispatcher_cfg)
            self.logger.info("[%d] START => %s (connection)", sub_job.id,
                             worker.hostname)
            send_multipart_u(self.controler, [
                worker.hostname, 'START',
                str(sub_job.id),
                self.export_definition(sub_job),
                yaml.dump(min_device_cfg), dispatcher_cfg, env_str, env_dut_str
            ])

    def start_jobs(self, jobs=None):
        """
        Loop on all scheduled jobs and send the START message to the slave.
        """
        # make the request atomic
        query = TestJob.objects.select_for_update()
        # Only select test jobs that are ready
        query = query.filter(state=TestJob.STATE_SCHEDULED)
        # Only start jobs on online workers
        query = query.filter(
            actual_device__worker_host__state=Worker.STATE_ONLINE)
        # exclude test jobs without a device: they are special test jobs like
        # dynamic connection.
        query = query.exclude(actual_device=None)
        # Allow for partial scheduling
        if jobs is not None:
            query = query.filter(id__in=jobs)

        # Loop on all jobs
        for job in query:
            msg = None
            try:
                self.start_job(job)
            except jinja2.TemplateNotFound as exc:
                self.logger.error("[%d] Template not found: '%s'", job.id,
                                  exc.message)
                msg = "Template not found: '%s'" % exc.message
            except jinja2.TemplateSyntaxError as exc:
                self.logger.error(
                    "[%d] Template syntax error in '%s', line %d: %s", job.id,
                    exc.name, exc.lineno, exc.message)
                msg = "Template syntax error in '%s', line %d: %s" % (
                    exc.name, exc.lineno, exc.message)
            except OSError as exc:
                self.logger.error("[%d] Unable to read '%s': %s", job.id,
                                  exc.filename, exc.strerror)
                msg = "Cannot open '%s': %s" % (exc.filename, exc.strerror)
            except yaml.YAMLError as exc:
                self.logger.error("[%d] Unable to parse job definition: %s",
                                  job.id, exc)
                msg = "Cannot parse job definition: %s" % exc

            if msg:
                # Add the error as lava.job result
                metadata = {
                    "case": "job",
                    "definition": "lava",
                    "error_type": "Infrastructure",
                    "error_msg": msg,
                    "result": "fail"
                }
                suite, _ = TestSuite.objects.get_or_create(name="lava",
                                                           job=job)
                TestCase.objects.create(name="job",
                                        suite=suite,
                                        result=TestCase.RESULT_FAIL,
                                        metadata=yaml.dump(metadata))
                job.go_state_finished(TestJob.HEALTH_INCOMPLETE, True)
                job.save()

    def cancel_jobs(self, partial=False):
        # make the request atomic
        query = TestJob.objects.select_for_update()
        # Only select the test jobs that are canceling
        query = query.filter(state=TestJob.STATE_CANCELING)
        # Only cancel jobs on online workers
        query = query.filter(
            actual_device__worker_host__state=Worker.STATE_ONLINE)

        # Allow for partial canceling
        if partial:
            query = query.filter(id__in=list(self.events["canceling"]))

        # Loop on all jobs
        for job in query:
            worker = job.lookup_worker if job.dynamic_connection else job.actual_device.worker_host
            self.logger.info("[%d] CANCEL => %s", job.id, worker.hostname)
            send_multipart_u(self.controler,
                             [worker.hostname, 'CANCEL',
                              str(job.id)])

    def handle(self, *args, **options):
        # Initialize logging.
        self.setup_logging("lava-master", options["level"],
                           options["log_file"], FORMAT)

        self.logger.info("[INIT] Dropping privileges")
        if not self.drop_privileges(options['user'], options['group']):
            self.logger.error("[INIT] Unable to drop privileges")
            return

        filename = os.path.join(settings.MEDIA_ROOT, 'lava-master-config.yaml')
        self.logger.debug("[INIT] Dumping config to %s", filename)
        with open(filename, 'w') as output:
            yaml.dump(options, output)

        self.logger.info("[INIT] Marking all workers as offline")
        with transaction.atomic():
            for worker in Worker.objects.select_for_update().all():
                worker.go_state_offline()
                worker.save()

        # Create the sockets
        context = zmq.Context()
        self.controler = context.socket(zmq.ROUTER)
        self.event_socket = context.socket(zmq.SUB)

        if options['ipv6']:
            self.logger.info("[INIT] Enabling IPv6")
            self.controler.setsockopt(zmq.IPV6, 1)
            self.event_socket.setsockopt(zmq.IPV6, 1)

        if options['encrypt']:
            self.logger.info("[INIT] Starting encryption")
            try:
                self.auth = ThreadAuthenticator(context)
                self.auth.start()
                self.logger.debug("[INIT] Opening master certificate: %s",
                                  options['master_cert'])
                master_public, master_secret = zmq.auth.load_certificate(
                    options['master_cert'])
                self.logger.debug("[INIT] Using slaves certificates from: %s",
                                  options['slaves_certs'])
                self.auth.configure_curve(domain='*',
                                          location=options['slaves_certs'])
            except OSError as err:
                self.logger.error(err)
                self.auth.stop()
                return
            self.controler.curve_publickey = master_public
            self.controler.curve_secretkey = master_secret
            self.controler.curve_server = True

            self.logger.debug("[INIT] Watching %s", options["slaves_certs"])
            self.inotify_fd = watch_directory(options["slaves_certs"])
            if self.inotify_fd is None:
                self.logger.error("[INIT] Unable to start inotify")

        self.controler.setsockopt(zmq.IDENTITY, b"master")
        # From http://api.zeromq.org/4-2:zmq-setsockopt#toc42
        # "If two clients use the same identity when connecting to a ROUTER
        # [...] the ROUTER socket shall hand-over the connection to the new
        # client and disconnect the existing one."
        self.controler.setsockopt(zmq.ROUTER_HANDOVER, 1)
        self.controler.bind(options['master_socket'])

        self.event_socket.setsockopt(zmq.SUBSCRIBE, b(settings.EVENT_TOPIC))
        self.event_socket.connect(options['event_url'])

        # Poll on the sockets. This allows us to have a
        # nice timeout along with polling.
        self.poller = zmq.Poller()
        self.poller.register(self.controler, zmq.POLLIN)
        self.poller.register(self.event_socket, zmq.POLLIN)
        if self.inotify_fd is not None:
            self.poller.register(os.fdopen(self.inotify_fd), zmq.POLLIN)

        # Translate signals into zmq messages
        (self.pipe_r, _) = self.setup_zmq_signal_handler()
        self.poller.register(self.pipe_r, zmq.POLLIN)

        self.logger.info("[INIT] LAVA master has started.")
        self.logger.info("[INIT] Using protocol version %d", PROTOCOL_VERSION)

        try:
            self.main_loop(options)
        except BaseException as exc:
            self.logger.error("[CLOSE] Unknown exception raised, leaving!")
            self.logger.exception(exc)
        finally:
            # Drop controler socket: the protocol does handle lost messages
            self.logger.info(
                "[CLOSE] Closing the controler socket and dropping messages")
            self.controler.close(linger=0)
            self.event_socket.close(linger=0)
            if options['encrypt']:
                self.auth.stop()
            context.term()

    def main_loop(self, options):
        last_schedule = last_dispatcher_check = time.time()

        while True:
            try:
                try:
                    # Compute the timeout
                    now = time.time()
                    timeout = min(
                        SCHEDULE_INTERVAL - (now - last_schedule),
                        PING_INTERVAL - (now - last_dispatcher_check))
                    # If some actions are remaining, decrease the timeout
                    if any(self.events.values()):
                        timeout = min(timeout, 2)
                    # Wait at least for 1ms
                    timeout = max(timeout * 1000, 1)

                    # Wait for data or a timeout
                    sockets = dict(self.poller.poll(timeout))
                except zmq.error.ZMQError:
                    continue

                if sockets.get(self.pipe_r) == zmq.POLLIN:
                    self.logger.info("[POLL] Received a signal, leaving")
                    break

                # Command socket
                if sockets.get(self.controler) == zmq.POLLIN:
                    while self.controler_socket(
                    ):  # Unqueue all pending messages
                        pass

                # Events socket
                if sockets.get(self.event_socket) == zmq.POLLIN:
                    # Unqueue all pending messages
                    while self.read_event_socket():
                        pass
                    # Wait for the next iteration to handle the event.
                    # In fact, the code that generated the event (lava-logs or
                    # lava-server-gunicorn) needs some time to commit the
                    # database transaction.
                    # If we are too fast, the database object won't be
                    # available (or in the right state) yet.
                    continue

                # Inotify socket
                if sockets.get(self.inotify_fd) == zmq.POLLIN:
                    os.read(self.inotify_fd, 4096)
                    self.logger.debug("[AUTH] Reloading certificates from %s",
                                      options['slaves_certs'])
                    self.auth.configure_curve(domain='*',
                                              location=options['slaves_certs'])

                # Check dispatchers status
                now = time.time()
                if now - last_dispatcher_check > PING_INTERVAL:
                    for hostname, dispatcher in self.dispatchers.items():
                        if dispatcher.online and now - dispatcher.last_msg > DISPATCHER_TIMEOUT:
                            if hostname == "lava-logs":
                                self.logger.error(
                                    "[STATE] lava-logs goes OFFLINE")
                            else:
                                self.logger.error(
                                    "[STATE] Dispatcher <%s> goes OFFLINE",
                                    hostname)
                            self.dispatchers[hostname].go_offline()
                    last_dispatcher_check = now

                # Limit accesses to the database. This will also limit the rate of
                # CANCEL and START messages
                if time.time() - last_schedule > SCHEDULE_INTERVAL:
                    if self.dispatchers["lava-logs"].online:
                        schedule(self.logger)

                        # Dispatch scheduled jobs
                        with transaction.atomic():
                            self.start_jobs()
                    else:
                        self.logger.warning(
                            "lava-logs is offline: can't schedule jobs")

                    # Handle canceling jobs
                    with transaction.atomic():
                        self.cancel_jobs()

                    # Do not count the time taken to schedule jobs
                    last_schedule = time.time()
                else:
                    # Cancel the jobs and remove the jobs from the set
                    if self.events["canceling"]:
                        with transaction.atomic():
                            self.cancel_jobs(partial=True)
                        self.events["canceling"] = set()
                    # Schedule for available device-types
                    if self.events["available_dt"]:
                        jobs = schedule(self.logger,
                                        self.events["available_dt"])
                        self.events["available_dt"] = set()
                        # Dispatch scheduled jobs
                        with transaction.atomic():
                            self.start_jobs(jobs)

            except (OperationalError, InterfaceError):
                self.logger.info("[RESET] database connection reset.")
                # Closing the database connection will force Django to reopen
                # the connection
                connection.close()
                time.sleep(2)
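# --- Hedged sketch (client side, not part of lava-master) ---------------------
# When --encrypt is used, the master above acts as the CURVE server: it loads
# master.key_secret and accepts any slave whose public certificate sits in the
# --slaves-certs directory. For context, the matching slave-side socket setup
# looks roughly like the following; the hostname and file paths are placeholders.
import zmq
import zmq.auth

context = zmq.Context()
sock = context.socket(zmq.DEALER)
sock.setsockopt(zmq.IDENTITY, b"worker01")

# The slave's own key pair; its public half must be copied into the master's
# --slaves-certs directory so the ThreadAuthenticator will accept it.
slave_public, slave_secret = zmq.auth.load_certificate(
    "/etc/lava-dispatcher/certificates.d/worker01.key_secret")
# Only the *public* part of the master certificate is needed on the slave.
master_public, _ = zmq.auth.load_certificate(
    "/etc/lava-dispatcher/certificates.d/master.key")

sock.curve_publickey = slave_public
sock.curve_secretkey = slave_secret
sock.curve_serverkey = master_public   # must be set before connect()
sock.connect("tcp://lava-master.example.com:5556")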
Example #47
class Actor(object):
    '''The actor class implements all the management and control functions over its components
    
    :param gModel: the JSON-based dictionary holding the model for the app this actor belongs to.
    :type gModel: dict
    :param gModelName: the name of the top-level model for the app
    :type gModelName: str
    :param aName: name of the actor. It is an index into the gModel that points to the part of the model specific to the actor
    :type aName: str
    :param sysArgv: list of arguments for the actor: -key1 value1 -key2 value2 ...
    :type sysArgv: list
         
    '''
    def __init__(self, gModel, gModelName, aName, sysArgv):
        '''
        Constructor
        '''
        self.logger = logging.getLogger(__name__)
        self.inst_ = self
        self.appName = gModel["name"]
        self.modelName = gModelName
        self.name = aName
        self.pid = os.getpid()
        self.uuid = None
        self.setupIfaces()
        # Assumption : pid is a 4 byte int
        self.actorID = ipaddress.IPv4Address(
            self.globalHost).packed + self.pid.to_bytes(4, 'big')
        self.suffix = ""
        if aName not in gModel["actors"]:
            raise BuildError('Actor "%s" unknown' % aName)
        self.model = gModel["actors"][
            aName]  # Fetch the relevant content from the model

        self.INT_RE = re.compile(r"^[-]?\d+$")
        self.parseParams(sysArgv)

        # Use czmq's context
        czmq_ctx = Zsys.init()
        self.context = zmq.Context.shadow(czmq_ctx.value)
        Zsys.handler_reset()  # Reset previous signal handler

        # Context for app sockets
        self.appContext = zmq.Context()

        if Config.SECURITY:
            (self.public_key,
             self.private_key) = zmq.auth.load_certificate(const.appCertFile)
            hosts = ['127.0.0.1']
            try:
                with open(const.appDescFile, 'r') as f:
                    content = yaml.load(f)
                    hosts += content.hosts
            except:
                pass

            self.auth = ThreadAuthenticator(self.appContext)
            self.auth.start()
            self.auth.allow(*hosts)
            self.auth.configure_curve(domain='*',
                                      location=zmq.auth.CURVE_ALLOW_ANY)
        else:
            (self.public_key, self.private_key) = (None, None)
            self.auth = None
            self.appContext = self.context

        try:
            if os.path.isfile(const.logConfFile) and os.access(
                    const.logConfFile, os.R_OK):
                spdlog_setup.from_file(const.logConfFile)
        except Exception as e:
            self.logger.error("error while configuring componentLogger: %s" %
                              repr(e))

        # Global message types (global on the network)
        messages = gModel["messages"]
        self.messageNames = []
        for messageSpec in messages:
            self.messageNames.append(messageSpec["name"])

        # Local message types (local to the host)
        locals_ = self.model["locals"]
        self.localNames = []
        for messageSpec in locals_:
            self.localNames.append(messageSpec["type"])

        # Internal message types (internal to the actor process)
        internals = self.model["internals"]
        self.internalNames = []
        for messageSpec in internals:
            self.internalNames.append(messageSpec["type"])

        self.components = {}
        instSpecs = self.model["instances"]
        compSpecs = gModel["components"]
        ioSpecs = gModel["devices"]
        for instName in instSpecs:  # Create the component instances: the 'parts'
            instSpec = instSpecs[instName]
            instType = instSpec['type']
            if instType in compSpecs:
                typeSpec = compSpecs[instType]
                ioComp = False
            elif instType in ioSpecs:
                typeSpec = ioSpecs[instType]
                ioComp = True
            else:
                raise BuildError(
                    'Component type "%s" for instance "%s" is undefined' %
                    (instType, instName))
            instFormals = typeSpec['formals']
            instActuals = instSpec['actuals']
            instArgs = self.buildInstArgs(instName, instFormals, instActuals)

            # Check whether the component is C++ component
            ccComponentFile = 'lib' + instType.lower() + '.so'
            ccComp = os.path.isfile(ccComponentFile)
            try:
                if not ioComp:
                    if ccComp:
                        modObj = importlib.import_module('lib' +
                                                         instType.lower())
                        self.components[instName] = modObj.create_component_py(
                            self, self.model, typeSpec, instName, instType,
                            instArgs, self.appName, self.name)
                    else:
                        self.components[instName] = Part(
                            self, typeSpec, instName, instType, instArgs)
                else:
                    self.components[instName] = Peripheral(
                        self, typeSpec, instName, instType, instArgs)
            except Exception as e:
                traceback.print_exc()
                self.logger.error("Error while constructing part '%s.%s': %s" %
                                  (instType, instName, str(e)))

    def getParameterValueType(self, param, defaultType):
        ''' Infer the type of a parameter from its value unless a default type is provided. \
            In the latter case the parameter's value is converted to that type.
            
            :param param: a parameter value
            :type param: one of bool,int,float,str
            :param defaultType:
            :type defaultType: one of bool,int,float,str
            :return: a pair (value,type)
            :rtype: tuple
             
        '''
        paramValue, paramType = None, None
        if defaultType != None:
            if defaultType == str:
                paramValue, paramType = param, str
            elif defaultType == int:
                paramValue, paramType = int(param), int
            elif defaultType == float:
                paramValue, paramType = float(param), float
            elif defaultType == bool:
                paramType = bool
                paramValue = False if param == "False" else True if param == "True" else None
        else:
            if param == 'True':
                paramValue, paramType = True, bool
            elif param == 'False':
                paramValue, paramType = False, bool
            elif self.INT_RE.match(param) is not None:
                paramValue, paramType = int(param), int
            else:
                try:
                    paramValue, paramType = float(param), float
                except:
                    paramValue, paramType = str(param), str
        return (paramValue, paramType)

    def parseParams(self, sysArgv):
        '''Parse actor arguments from the command line
        
        Compares the actual arguments to the formal arguments (from the model) and
        fills out the local parameter table accordingly. Generates a warning on 
        extra arguments and raises an exception on required but missing ones.
           
        '''
        self.params = {}
        formals = self.model["formals"]
        optList = []
        for formal in formals:
            key = formal["name"]
            default = None if "default" not in formal else formal["default"]
            self.params[key] = default
            optList.append("%s=" % key)
        try:
            opts, _args = getopt.getopt(sysArgv, '', optList)
        except getopt.GetoptError:
            self.logger.info("Error parsing actor options %s" % str(sysArgv))
            return
        for opt in opts:
            optName2, optValue = opt
            optName = optName2[2:]  # Drop two leading dashes
            if optName in self.params:
                defaultType = None if self.params[optName] == None else type(
                    self.params[optName])
                paramValue, paramType = self.getParameterValueType(
                    optValue, defaultType)
                if self.params[optName] != None:
                    if paramType != type(self.params[optName]):
                        raise BuildError(
                            "Type of default value does not match type of argument %s"
                            % str((optName, optValue)))
                self.params[optName] = paramValue
            else:
                self.logger.info("Unknown argument %s - ignored" % optName)
        for param in self.params:
            if self.params[param] == None:
                raise BuildError("Required parameter %s missing" % param)

    def buildInstArgs(self, instName, formals, actuals):
        args = {}
        for formal in formals:
            argName = formal['name']
            argValue = None
            actual = next(
                (actual for actual in actuals if actual['name'] == argName),
                None)
            defaultValue = None
            if 'default' in formal:
                defaultValue = formal['default']
            if actual != None:
                assert (actual['name'] == argName)
                if 'param' in actual:
                    paramName = actual['param']
                    if paramName in self.params:
                        argValue = self.params[paramName]
                    else:
                        raise BuildError(
                            "Unspecified parameter %s referenced in %s" %
                            (paramName, instName))
                elif 'value' in actual:
                    argValue = actual['value']
                else:
                    raise BuildError("Actual parameter %s has no value" %
                                     argName)
            elif defaultValue != None:
                argValue = defaultValue
            else:
                raise BuildError("Argument %s in %s has no defined value" %
                                 (argName, instName))
            args[argName] = argValue
        return args

    def isLocalMessage(self, msgTypeName):
        '''Return True if the message type is local
        
        '''
        return msgTypeName in self.localNames

    def isInnerMessage(self, msgTypeName):
        '''Return True if the message type is internal
        
        '''
        return msgTypeName in self.internalNames

    def getLocalIface(self):
        '''Return the IP address of the host-local network interface (usually 127.0.0.1) 
        '''
        return self.localHost

    def getGlobalIface(self):
        '''Return the IP address of the global network interface
        '''
        return self.globalHost

    def getActorName(self):
        '''Return the name of this actor (as defined in the app model)
        '''
        return self.name

    def getAppName(self):
        '''Return the name of the app this actor belongs to
        '''
        return self.appName

    def getActorID(self):
        '''Returns an ID for this actor.
        
        The actor's id is constructed from the host's IP address and the actor's process id.
        The id is unique for a given host and actor run.
        '''
        return self.actorID

    def setUUID(self, uuid):
        '''Sets the UUID for this actor.
        
        The UUID is dynamically generated (by the peer-to-peer network system)
        and is unique. 
        '''
        self.uuid = uuid

    def getUUID(self):
        '''Return the UUID for this actor. 
        '''
        return self.uuid

    def setupIfaces(self):
        '''Find the IP addresses of the (host-)local and network(-global) interfaces
        
        '''
        (globalIPs, globalMACs, _globalNames, localIP) = getNetworkInterfaces()
        try:
            assert len(globalIPs) > 0 and len(globalMACs) > 0
        except:
            self.logger.error("Error: no active network interface")
            raise
        globalIP = globalIPs[0]
        globalMAC = globalMACs[0]
        self.localHost = localIP
        self.globalHost = globalIP
        self.macAddress = globalMAC

    def setup(self):
        '''Perform a setup operation on the actor, after the initial construction
        but before the activation of parts
        
        '''
        self.logger.info("setup")
        self.suffix = self.macAddress
        self.disco = DiscoClient(self, self.suffix)
        self.disco.start()  # Start the discovery service client
        self.disco.registerApp()  # Register this actor with the discovery service
        self.logger.info("actor registered with disco")
        self.deplc = DeplClient(self, self.suffix)
        self.deplc.start()
        ok = self.deplc.registerApp()
        self.logger.info("actor %s registered with depl" %
                         ("is" if ok else "is not"))

        self.controls = {}
        self.controlMap = {}
        for inst in self.components:
            comp = self.components[inst]
            control = self.context.socket(zmq.PAIR)
            control.bind('inproc://part_' + inst + '_control')
            self.controls[inst] = control
            self.controlMap[id(control)] = comp
            if isinstance(comp, Part):
                self.components[inst].setup(control)
            else:
                self.components[inst].setup()

    def registerEndpoint(self, bundle):
        '''
        Relay the endpoint registration message to the discovery service client 
        '''
        self.logger.info("registerEndpoint")
        result = self.disco.registerEndpoint(bundle)
        for res in result:
            (partName, portName, host, port) = res
            self.updatePart(partName, portName, host, port)

    def registerDevice(self, bundle):
        '''Relay the device registration message to the device interface service client
        
        '''
        typeName, args = bundle
        msg = (self.appName, self.modelName, typeName, args)
        result = self.deplc.registerDevice(msg)
        return result

    def unregisterDevice(self, bundle):
        '''Relay the device unregistration message to the device interface service client
        
        '''
        typeName, = bundle
        msg = (self.appName, self.modelName, typeName)
        result = self.deplc.unregisterDevice(msg)
        return result

    def activate(self):
        '''Activate the parts
        
        '''
        self.logger.info("activate")
        for inst in self.components:
            self.components[inst].activate()

    def deactivate(self):
        '''Deactivate the parts
        
        '''
        self.logger.info("deactivate")
        for inst in self.components:
            self.components[inst].deactivate()

    def recvChannelMessages(self, channel):
        '''Collect all messages from the channel queue and return them in a list
        '''
        msgs = []
        while True:
            try:
                msg = channel.recv(flags=zmq.NOBLOCK)
                msgs.append(msg)
            except zmq.Again:
                break
        return msgs

    def start(self):
        '''
        Start and operate the actor (infinite polling loop)
        '''
        self.logger.info("starting")
        self.discoChannel = self.disco.channel  # Private channel to the discovery service
        self.deplChannel = self.deplc.channel

        self.poller = zmq.Poller()  # Set up the poller
        self.poller.register(self.deplChannel, zmq.POLLIN)
        self.poller.register(self.discoChannel, zmq.POLLIN)
        for control in self.controls:
            self.poller.register(self.controls[control], zmq.POLLIN)

        while 1:
            sockets = dict(self.poller.poll())
            if self.discoChannel in sockets:  # If there is a message from a service, handle it
                msgs = self.recvChannelMessages(self.discoChannel)
                for msg in msgs:
                    self.handleServiceUpdate(msg)  # Handle message from disco service
                del sockets[self.discoChannel]
            elif self.deplChannel in sockets:
                msgs = self.recvChannelMessages(self.deplChannel)
                for msg in msgs:
                    self.handleDeplMessage(msg)  # Handle message from depl service
                del sockets[self.deplChannel]
            else:  # Handle messages from the components.
                toDelete = []
                for s in sockets:
                    if s in self.controls.values():
                        part = self.controlMap[id(s)]
                        msg = s.recv_pyobj()  # Receive python object from component
                        self.handleEventReport(part, msg)  # Report event
                    toDelete += [s]
                for s in toDelete:
                    del sockets[s]

    def handleServiceUpdate(self, msgBytes):
        '''
        Handle a service update message from the discovery service
        '''
        msgUpd = disco_capnp.DiscoUpd.from_bytes(
            msgBytes)  # Parse the incoming message

        which = msgUpd.which()
        if which == 'portUpdate':
            msg = msgUpd.portUpdate
            client = msg.client
            actorHost = client.actorHost
            assert actorHost == self.globalHost  # It has to be addressed to this actor
            actorName = client.actorName
            assert actorName == self.name
            instanceName = client.instanceName
            assert instanceName in self.components  # It has to be for a part of this actor
            portName = client.portName
            scope = msg.scope
            socket = msg.socket
            host = socket.host
            port = socket.port
            if scope == "local":
                assert host == self.localHost
            self.updatePart(instanceName, portName, host,
                            port)  # Update the selected part

    def updatePart(self, instanceName, portName, host, port):
        '''
        Ask a part to update itself
        '''
        self.logger.info("updatePart %s" % str(
            (instanceName, portName, host, port)))
        part = self.components[instanceName]
        part.handlePortUpdate(portName, host, port)

    def handleDeplMessage(self, msgBytes):
        '''
        Handle a message from the deployment service
        '''
        msgUpd = deplo_capnp.DeplCmd.from_bytes(
            msgBytes)  # Parse the incoming message

        which = msgUpd.which()
        if which == 'resourceMsg':
            what = msgUpd.resourceMsg.which()
            if what == 'resCPUX':
                self.handleCPULimit()
            elif what == 'resMemX':
                self.handleMemLimit()
            elif what == 'resSpcX':
                self.handleSpcLimit()
            elif what == 'resNetX':
                self.handleNetLimit()
            else:
                self.logger.error("unknown resource msg from deplo: '%s'" %
                                  what)
                pass
        elif which == 'reinstateCmd':
            self.handleReinstate()
        elif which == 'nicStateMsg':
            stateMsg = msgUpd.nicStateMsg
            state = str(stateMsg.nicState)
            self.handleNICStateChange(state)
        elif which == 'peerInfoMsg':
            peerMsg = msgUpd.peerInfoMsg
            state = str(peerMsg.peerState)
            uuid = peerMsg.uuid
            self.handlePeerStateChange(state, uuid)
        else:
            self.logger.error("unknown msg from deplo: '%s'" % which)
            pass

    def handleReinstate(self):
        self.logger.info('handleReinstate')
        self.poller.unregister(self.discoChannel)
        self.disco.reconnect()
        self.discoChannel = self.disco.channel
        self.poller.register(self.discoChannel, zmq.POLLIN)
        for inst in self.components:
            self.components[inst].handleReinstate()

    def handleNICStateChange(self, state):
        '''
        Handle the NIC state change message: notify components   
        '''
        self.logger.info("handleNICStateChange")
        for component in self.components.values():
            component.handleNICStateChange(state)

    def handlePeerStateChange(self, state, uuid):
        '''
        Handle the peer state change message: notify components   
        '''
        self.logger.info("handlePeerStateChange")
        for component in self.components.values():
            component.handlePeerStateChange(state, uuid)

    def handleCPULimit(self):
        '''
        Handle the case when the CPU limit is exceeded: notify each component.
        If the component has defined a handler, it will be called.   
        '''
        self.logger.info("handleCPULimit")
        for component in self.components.values():
            component.handleCPULimit()

    def handleMemLimit(self):
        '''
        Handle the case when the memory limit is exceeded: notify each component.
        If the component has defined a handler, it will be called.   
        '''
        self.logger.info("handleMemLimit")
        for component in self.components.values():
            component.handleMemLimit()

    def handleSpcLimit(self):
        '''
        Handle the case when the file space limit is exceeded: notify each component.
        If the component has defined a handler, it will be called.   
        '''
        self.logger.info("handleSpcLimit")
        for component in self.components.values():
            component.handleSpcLimit()

    def handleNetLimit(self):
        '''
        Handle the case when the net usage limit is exceeded: notify each component.
        If the component has defined a handler, it will be called.   
        '''
        self.logger.info("handleNetLimit")
        for component in self.components.values():
            component.handleNetLimit()

    def handleEventReport(self, part, msg):
        '''Handle event report from a part
        
        The event report is forwarded to the deplo service. 
        '''
        partName = part.getName()
        typeName = part.getTypeName()
        bundle = (
            partName,
            typeName,
        ) + (msg, )
        self.deplc.reportEvent(bundle)

    def terminate(self):
        '''Terminate all functions of the actor. 
        
        Terminate all components, and connections to the deplo/disco services.
        Finally exit the process. 
        '''
        self.logger.info("terminating")
        for component in self.components.values():
            component.terminate()
        time.sleep(1.0)
        self.deplc.terminate()
        self.disco.terminate()
        if self.auth:
            self.auth.stop()
        # Clean up everything
        # self.context.destroy()
        # time.sleep(1.0)
        self.logger.info("terminated")
        os._exit(0)
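# --- Hedged sketch (component side, not part of the Actor class) --------------
# setup() above binds one inproc PAIR socket per component instance
# ('inproc://part_<inst>_control') and start() polls them, calling recv_pyobj()
# and forwarding each report through handleEventReport(). The matching part-side
# code is not shown in this example; a rough sketch, with invented names, is:
import zmq

def open_part_control(context, inst_name):
    """Connect the component end of the actor's per-part control channel."""
    sock = context.socket(zmq.PAIR)
    # Mirrors the actor's control.bind('inproc://part_' + inst + '_control');
    # both ends must share the same zmq.Context for inproc to work.
    sock.connect('inproc://part_' + inst_name + '_control')
    return sock

# A part could then report an event that the actor forwards to the deplo service:
#   control = open_part_control(actor.context, "my_part")
#   control.send_pyobj({"event": "example"})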