Пример #1
0
class MispZmq:
    """Republish MISP messages queued in Redis on a ZeroMQ PUB socket.

    The main loop blocks on the Redis lists ``<namespace>:command`` and
    ``<namespace>:data:<topic>``. Data entries are sent on the PUB socket as
    ``"<topic> <payload>"``; command entries (``kill``, ``reload``,
    ``status``) control the process. Settings are read from
    ``../tmp/mispzmq_settings.json`` relative to this file's directory.
    """

    # Counters reported by the "status" command.
    message_count = 0  # data messages republished, all topics
    publish_count = 0  # data messages republished on the "misp_json" topic

    monitor_thread = None
    auth = None
    socket = None
    pidfile = None

    r: redis.StrictRedis
    namespace: str

    def __init__(self):
        """Refuse to start if another instance is alive, then load settings.

        :raises Exception: if the PID file names a process that ``check_pid``
            reports as running, or if the settings file is missing.
        """
        self._logger = logging.getLogger()

        self.tmp_location = Path(__file__).parent.parent / "tmp"
        self.pidfile = self.tmp_location / "mispzmq.pid"
        if self.pidfile.exists():
            with open(self.pidfile.as_posix()) as f:
                pid = f.read()
            if check_pid(pid):
                raise Exception(
                    "mispzmq already running on PID {}".format(pid))
            else:
                # Stale PID file left behind by a process that is gone.
                self.pidfile.unlink()
        if (self.tmp_location / "mispzmq_settings.json").exists():
            self._setup()
        else:
            raise Exception("The settings file is missing.")

    def _setup(self):
        """Load the JSON settings file and open the Redis connection."""
        with open((self.tmp_location /
                   "mispzmq_settings.json").as_posix()) as settings_file:
            self.settings = json.load(settings_file)
        self.namespace = self.settings["redis_namespace"]
        self.r = redis.StrictRedis(host=self.settings["redis_host"],
                                   db=self.settings["redis_database"],
                                   password=self.settings["redis_password"],
                                   port=self.settings["redis_port"],
                                   decode_responses=True)
        self.timestamp_settings = time.time()
        self._logger.debug("Connected to Redis {}:{}/{}".format(
            self.settings["redis_host"], self.settings["redis_port"],
            self.settings["redis_database"]))

    def _setup_zmq(self):
        """(Re)create the PUB socket from the current settings.

        Also runs on the ``reload`` command. The previous socket, its monitor
        and the authenticator are released first. Without that, ``bind``
        fails because the old socket still holds host:port, and the old
        authenticator thread keeps running.

        :raises Exception: if a username is configured without a password.
        """
        # .get(): a settings file without a "username" key means "no auth"
        # rather than a KeyError.
        username = self.settings.get("username")
        password = self.settings.get("password")
        # Validate before any resources are released or created.
        if username and not password:
            raise Exception(
                "When username is set, password cannot be empty.")

        self._teardown_zmq()

        context = zmq.Context()
        if username:
            self.auth = ThreadAuthenticator(context)
            self.auth.start()
            self.auth.configure_plain(domain="*",
                                      passwords={username: password})

        self.socket = context.socket(zmq.PUB)
        if username:
            self.socket.plain_server = True  # must come before bind
        self.socket.bind("tcp://{}:{}".format(self.settings["host"],
                                              self.settings["port"]))
        self._logger.debug("ZMQ listening on tcp://{}:{}".format(
            self.settings["host"], self.settings["port"]))

        if self._logger.isEnabledFor(logging.DEBUG):
            # Only when DEBUG is enabled: hand the socket's monitor socket
            # and the logger to event_monitor on a background thread.
            monitor = self.socket.get_monitor_socket()
            self.monitor_thread = threading.Thread(target=event_monitor,
                                                   args=(monitor,
                                                         self._logger))
            self.monitor_thread.start()

    def _teardown_zmq(self):
        """Release the socket monitor, the authenticator and the PUB socket.

        Each handle is reset to None afterwards, so calling this again (or
        before anything was created) is harmless.
        """
        if self.monitor_thread is not None and self.socket is not None:
            self.socket.disable_monitor()
        self.monitor_thread = None
        if self.auth is not None:
            self.auth.stop()
        self.auth = None
        if self.socket is not None:
            self.socket.close()
        self.socket = None

    def _handle_command(self, command):
        """Execute a control command popped from ``<namespace>:command``.

        :param command: one of ``"kill"``, ``"reload"`` or ``"status"``;
            anything else is logged as invalid.
        """
        if command == "kill":
            self._logger.info("Kill command received, shutting down.")
            self.clean()
            sys.exit()

        elif command == "reload":
            self._logger.info(
                "Reload command received, reloading settings from file.")
            self._setup()
            self._setup_zmq()

        elif command == "status":
            self._logger.info(
                "Status command received, responding with latest stats.")
            # Replace the previous status entry with a fresh one.
            self.r.delete("{}:status".format(self.namespace))
            self.r.lpush(
                "{}:status".format(self.namespace),
                json.dumps({
                    "timestamp": time.time(),
                    "timestampSettings": self.timestamp_settings,
                    "publishCount": self.publish_count,
                    "messageCount": self.message_count
                }))
        else:
            self._logger.warning(
                "Received invalid command '{}'.".format(command))

    def _create_pid_file(self):
        """Write this process's PID to the PID file."""
        with open(self.pidfile.as_posix(), "w") as f:
            f.write(str(os.getpid()))

    def _pub_message(self, topic, data):
        """Send ``"<topic> <data>"`` as one string frame on the PUB socket."""
        self.socket.send_string("{} {}".format(topic, data))

    def clean(self):
        """Release ZMQ resources and remove the PID file if it still exists."""
        self._teardown_zmq()
        if self.pidfile and self.pidfile.exists():
            self.pidfile.unlink()

    def main(self):
        """Forward Redis list entries to ZMQ until a ``kill`` command arrives."""
        self._create_pid_file()
        self._setup_zmq()
        # NOTE(review): presumably gives subscribers time to connect before
        # the first publish (PUB drops messages with no subscriber) - confirm.
        time.sleep(1)

        status_array = [
            "And when you're dead I will be still alive.",
            "And believe me I am still alive.",
            "I'm doing science and I'm still alive.",
            "I feel FANTASTIC and I'm still alive.",
            "While you're dying I'll be still alive.",
        ]
        topics = [
            "misp_json", "misp_json_event", "misp_json_attribute",
            "misp_json_sighting", "misp_json_organisation", "misp_json_user",
            "misp_json_conversation", "misp_json_object",
            "misp_json_object_reference", "misp_json_audit", "misp_json_tag",
            "misp_json_warninglist"
        ]

        prefix = "{}:".format(self.namespace)
        lists = [prefix + "command"] + [
            "{}data:{}".format(prefix, topic) for topic in topics
        ]

        while True:
            # Block up to 10 seconds on all lists; None means the timeout expired.
            data = self.r.blpop(lists, timeout=10)

            if data is None:
                current_time = int(time.time())
                elapsed = current_time - int(self.timestamp_settings)
                # Cycle through the five status lines, one step per 10 s.
                status_message = {
                    "status": status_array[int(elapsed / 10 % 5)],
                    "uptime": elapsed
                }
                self._pub_message("misp_json_self", json.dumps(status_message))
                self._logger.debug(
                    "No message received for 10 seconds, sending ZMQ status message."
                )
                continue

            key, value = data
            # Strip only the leading "<namespace>:"; str.replace would also
            # remove any later occurrence inside the key.
            if key.startswith(prefix):
                key = key[len(prefix):]
            if key == "command":
                self._handle_command(value)
            elif key.startswith("data:"):
                topic = key.split(":")[1]
                self._logger.debug(
                    "Received data for topic '{}', sending to ZMQ.".format(
                        topic))
                self._pub_message(topic, value)
                self.message_count += 1
                if topic == "misp_json":
                    self.publish_count += 1
            else:
                self._logger.warning(
                    "Received invalid message '{}'.".format(key))
Пример #2
0
class RpcClient:
    """ZeroMQ RPC client: a REQ socket for calls and a SUB socket for pushes.

    Any attribute that is not otherwise defined resolves to a proxy function.
    The proxy sends ``[name, args, kwargs]`` to the server and returns the
    result, or raises ``RemoteException``. Subclasses implement ``callback``
    to handle data published by the server.
    """

    def __init__(self):
        """Constructor"""
        # zmq port related
        self.__context: zmq.Context = zmq.Context()

        # Request socket (Request–reply pattern)
        self.__socket_req: zmq.Socket = self.__context.socket(zmq.REQ)
        # A strict REQ socket refuses to send until the reply to the previous
        # request has arrived. A single timed-out call in ``dorpc`` would
        # therefore make every later call fail with EFSM. REQ_RELAXED allows
        # a new request after a timeout; REQ_CORRELATE discards the late
        # reply to the abandoned request instead of returning it to the next
        # call.
        self.__socket_req.setsockopt(zmq.REQ_RELAXED, 1)
        self.__socket_req.setsockopt(zmq.REQ_CORRELATE, 1)

        # Subscribe socket (Publish–subscribe pattern)
        self.__socket_sub: zmq.Socket = self.__context.socket(zmq.SUB)

        # Worker thread related, used to process data pushed from server
        self.__active: bool = False  # RpcClient status
        self.__thread: threading.Thread = None  # RpcClient thread
        self.__lock: threading.Lock = threading.Lock()  # one REQ round-trip at a time

        # Authenticator used to ensure data security
        self.__authenticator: ThreadAuthenticator = None

        self._last_received_ping: datetime = datetime.utcnow()

    def __getattr__(self, name: str):
        """Return a proxy function that performs the remote call ``name``.

        Proxies are cached per instance, so repeated lookups return the same
        function object. Dunder names are never proxied. Protocol probes such
        as ``copy.deepcopy`` (``__deepcopy__``) or ``pickle``
        (``__getstate__``) get a normal AttributeError instead of a network
        call.

        Keyword argument ``timeout`` (milliseconds, default 30000) is consumed
        by the proxy and not forwarded to the server.
        """
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}")

        # Per-instance cache. functools.lru_cache on the method would keep
        # every RpcClient alive for the lifetime of the cache. Accessed via
        # __dict__ so the lookup can never recurse into __getattr__.
        cache = self.__dict__.setdefault("_rpc_proxies", {})
        proxy = cache.get(name)
        if proxy is not None:
            return proxy

        # Perform remote call task
        def dorpc(*args, **kwargs):
            timeout = kwargs.pop("timeout", 30000)

            req = [name, args, kwargs]

            # Send request and wait for the response (one call at a time).
            with self.__lock:
                self.__socket_req.send_pyobj(req)

                # Timeout reached without any data
                if not self.__socket_req.poll(timeout):
                    raise RemoteException(
                        f"Timeout of {timeout}ms reached for {req}")

                # NOTE: recv_pyobj unpickles the reply, so the server must be
                # trusted.
                rep = self.__socket_req.recv_pyobj()

            # Reply is [success_flag, result_or_traceback_text].
            if rep[0]:
                return rep[1]
            raise RemoteException(rep[1])

        cache[name] = dorpc
        return dorpc

    def start(self,
              req_address: str,
              sub_address: str,
              client_secretkey_path: str = "",
              server_publickey_path: str = "",
              username: str = "",
              password: str = "") -> None:
        """Configure security, connect both sockets and start the worker thread.

        CURVE is used when both key paths are given. Otherwise PLAIN is used
        when both username and password are given. Otherwise there is no
        authentication. Calling this while already active does nothing.
        """
        if self.__active:
            return

        # Start authenticator
        if client_secretkey_path and server_publickey_path:
            self.__authenticator = ThreadAuthenticator(self.__context)
            self.__authenticator.start()
            self.__authenticator.configure_curve(
                domain="*", location=zmq.auth.CURVE_ALLOW_ANY)

            publickey, secretkey = zmq.auth.load_certificate(
                client_secretkey_path)
            serverkey, _ = zmq.auth.load_certificate(server_publickey_path)

            self.__socket_sub.curve_secretkey = secretkey
            self.__socket_sub.curve_publickey = publickey
            self.__socket_sub.curve_serverkey = serverkey

            self.__socket_req.curve_secretkey = secretkey
            self.__socket_req.curve_publickey = publickey
            self.__socket_req.curve_serverkey = serverkey
        elif username and password:
            self.__authenticator = ThreadAuthenticator(self.__context)
            self.__authenticator.start()
            self.__authenticator.configure_plain(
                domain="*", passwords={username: password})

            self.__socket_sub.plain_username = username.encode()
            self.__socket_sub.plain_password = password.encode()

            self.__socket_req.plain_username = username.encode()
            self.__socket_req.plain_password = password.encode()

        # Connect zmq port
        self.__socket_req.connect(req_address)
        self.__socket_sub.connect(sub_address)

        # Start RpcClient status
        self.__active = True
        self._last_received_ping = datetime.utcnow()

        # Start RpcClient thread
        self.__thread = threading.Thread(target=self.run)
        self.__thread.start()

    def stop(self) -> None:
        """Ask the worker thread to exit; it closes the sockets on its way out."""
        if not self.__active:
            return

        # Stop RpcClient status
        self.__active = False

    def join(self) -> None:
        """Wait for the worker thread to exit, if it is running."""
        if self.__thread and self.__thread.is_alive():
            self.__thread.join()
        self.__thread = None

    def run(self) -> None:
        """Worker loop: dispatch pushed data to ``callback`` until stopped."""
        pull_tolerance = int(KEEP_ALIVE_TOLERANCE.total_seconds() * 1000)

        while self.__active:
            if not self.__socket_sub.poll(pull_tolerance):
                self.on_disconnected()
                continue

            # Receive data from subscribe socket.
            # NOTE: recv_pyobj unpickles network data; only trusted publishers.
            topic, data = self.__socket_sub.recv_pyobj(flags=NOBLOCK)

            if topic == KEEP_ALIVE_TOPIC:
                self._last_received_ping = data
            else:
                # Process data by callable function
                self.callback(topic, data)

        # Close socket
        self.__socket_req.close()
        self.__socket_sub.close()

        # Stop the authenticator thread started in start(), if any.
        if self.__authenticator is not None:
            self.__authenticator.stop()
            self.__authenticator = None

    def callback(self, topic: str, data: Any) -> None:
        """Handle one pushed message; subclasses must override."""
        raise NotImplementedError

    def subscribe_topic(self, topic: str) -> None:
        """Subscribe the SUB socket to ``topic``."""
        self.__socket_sub.setsockopt_string(zmq.SUBSCRIBE, topic)

    def on_disconnected(self):
        """
        Callback when heartbeat is lost.
        """
        print(
            "RpcServer has no response over {tolerance} seconds, please check you connection."
            .format(tolerance=KEEP_ALIVE_TOLERANCE.total_seconds()))
Пример #3
0
class RpcServer:
    """ZeroMQ RPC server: a REP socket for calls and a PUB socket for pushes.

    Functions added with ``register`` can be invoked by clients sending
    ``[name, args, kwargs]``. Each reply is ``[True, result]`` or
    ``[False, traceback_text]``.
    """

    def __init__(self):
        """
        Constructor
        """
        # Save functions dict: key is function name, value is function object
        self.__functions: Dict[str, Any] = {}

        # Zmq port related
        self.__context: zmq.Context = zmq.Context()

        # Reply socket (Request–reply pattern)
        self.__socket_rep: zmq.Socket = self.__context.socket(zmq.REP)

        # Publish socket (Publish–subscribe pattern)
        self.__socket_pub: zmq.Socket = self.__context.socket(zmq.PUB)

        # Worker thread related
        self.__active: bool = False  # RpcServer status
        self.__thread: threading.Thread = None  # RpcServer thread
        self.__lock: threading.Lock = threading.Lock()  # guards PUB socket sends

        # Authenticator used to ensure data security
        self.__authenticator: ThreadAuthenticator = None

    def is_active(self) -> bool:
        """Return the running flag set by ``start`` and cleared by ``stop``."""
        return self.__active

    def start(self,
              rep_address: str,
              pub_address: str,
              server_secretkey_path: str = "",
              username: str = "",
              password: str = "") -> None:
        """Configure security, bind both sockets and start the worker thread.

        CURVE is used when ``server_secretkey_path`` is given. Otherwise PLAIN
        is used when both username and password are given. Calling this while
        already active does nothing.
        """
        if self.__active:
            return

        # Start authenticator
        if server_secretkey_path:
            self.__authenticator = ThreadAuthenticator(self.__context)
            self.__authenticator.start()
            self.__authenticator.configure_curve(
                domain="*", location=zmq.auth.CURVE_ALLOW_ANY)

            publickey, secretkey = zmq.auth.load_certificate(
                server_secretkey_path)

            self.__socket_pub.curve_secretkey = secretkey
            self.__socket_pub.curve_publickey = publickey
            self.__socket_pub.curve_server = True

            self.__socket_rep.curve_secretkey = secretkey
            self.__socket_rep.curve_publickey = publickey
            self.__socket_rep.curve_server = True
        elif username and password:
            self.__authenticator = ThreadAuthenticator(self.__context)
            self.__authenticator.start()
            self.__authenticator.configure_plain(
                domain="*", passwords={username: password})

            self.__socket_pub.plain_server = True
            self.__socket_rep.plain_server = True

        # Bind socket address
        self.__socket_rep.bind(rep_address)
        self.__socket_pub.bind(pub_address)

        # Start RpcServer status
        self.__active = True

        # Start RpcServer thread
        self.__thread = threading.Thread(target=self.run)
        self.__thread.start()

    def stop(self) -> None:
        """Ask the worker thread to exit; it unbinds the sockets on its way out."""
        if not self.__active:
            return

        # Stop RpcServer status
        self.__active = False

    def join(self) -> None:
        """Wait for the worker thread to exit, if it is running."""
        if self.__thread and self.__thread.is_alive():
            self.__thread.join()
        self.__thread = None

    def run(self) -> None:
        """Worker loop: answer requests and publish heartbeats until stopped."""
        last_ping = datetime.utcnow()

        while self.__active:
            cur = datetime.utcnow()
            if cur - last_ping >= KEEP_ALIVE_INTERVAL:
                self.publish(KEEP_ALIVE_TOPIC, cur)
                # Reset the reference point. Without this, once the interval
                # had first elapsed, a heartbeat went out on every iteration.
                last_ping = cur

            # Wait up to 1 second (1000 milliseconds) for a request.
            if not self.__socket_rep.poll(1000):
                continue

            # NOTE: recv_pyobj unpickles network data; only trusted clients.
            req = self.__socket_rep.recv_pyobj()

            # A REP socket must answer every request, so failures are sent
            # back to the caller instead of being raised here.
            self.__socket_rep.send_pyobj(self._process_request(req))

        # Unbind socket address
        self.__socket_pub.unbind(self.__socket_pub.LAST_ENDPOINT)
        self.__socket_rep.unbind(self.__socket_rep.LAST_ENDPOINT)

        # Stop the authenticator. Otherwise a later start() would try to run
        # a second ZAP handler on the same context while this one still runs.
        if self.__authenticator is not None:
            self.__authenticator.stop()
            self.__authenticator = None

    def _process_request(self, req: Any) -> list:
        """Execute one ``[name, args, kwargs]`` request and build the reply.

        Returns ``[True, result]`` on success. Returns
        ``[False, traceback_text]`` when the request is malformed, the name is
        not registered, or the function raises.
        """
        try:
            name, args, kwargs = req
            func = self.__functions[name]
            return [True, func(*args, **kwargs)]
        except Exception:  # noqa: errors are reported to the client by design
            return [False, traceback.format_exc()]

    def publish(self, topic: str, data: Any) -> None:
        """Publish ``[topic, data]`` on the PUB socket (thread-safe via lock)."""
        with self.__lock:
            self.__socket_pub.send_pyobj([topic, data])

    def register(self, func: Callable) -> None:
        """Expose ``func`` to clients under its ``__name__``."""
        self.__functions[func.__name__] = func
Пример #4
0
class FeatureComputer(object):
    """Serve Keras model layer outputs over an authenticated ZMQ PAIR socket.

    Incoming messages are dicts with a ``'type'`` key (``change_layer``,
    ``predict``, ``layer_info``, ``summary``). Each handler replies with a
    dict on the same socket. ``self.layer`` indexes ``viewable_layers`` when
    given, otherwise ``parent_model.layers``.
    """

    def __init__(self,
                 bind_str="tcp://127.0.0.1:5560",
                 parent_model=None,
                 layer=None,
                 logins=None,
                 viewable_layers=None):
        self.context = zmq.Context.instance()
        self.auth = ThreadAuthenticator(self.context)
        self.auth.start()
        # PLAIN authentication against the ``logins`` mapping, any domain.
        self.auth.configure_plain(domain='*', passwords=logins)
        self.socket = self.context.socket(zmq.PAIR)
        self.socket.plain_server = True  # must be set before bind
        self.socket.bind(bind_str)
        self.parent_model = parent_model
        self.curr_model = parent_model
        self.viewable_layers = viewable_layers
        self.running = False

        self.config = tf.ConfigProto()
        self.config.gpu_options.per_process_gpu_memory_fraction = 0.3
        self.config.gpu_options.allow_growth = True

        # ``layer=0`` is a valid index; only fall back to the default (5)
        # when no layer was supplied at all.
        self.layer = 5 if layer is None else layer

    def _selectable_layers(self):
        """Return the layer list that ``self.layer`` indexes into."""
        if self.viewable_layers:
            return self.viewable_layers
        return self.parent_model.layers

    def change_layer(self, *args, **kwargs):
        """Rebuild ``curr_model`` to output ``kwargs['layer']`` and acknowledge."""
        print("Changing layer")
        print(args, kwargs)
        self.layer = kwargs.get('layer', self.layer)
        self.curr_model = Model(
            input=[self.parent_model.layers[0].input],
            output=[self._selectable_layers()[self.layer].output])
        print(self.layer)
        self.socket.send_pyobj({
            'type': 'layer_changed',
            'success': True
        }, zmq.NOBLOCK)

    def get_summary(self, *args, **kwargs):
        """Reply with a ``summary`` message whose result is None.

        Not used by ``handle_message``, which routes ``summary`` to
        ``do_summary``.
        """
        self.socket.send_pyobj({
            'type': 'summary',
            'result': None
        }, zmq.NOBLOCK)

    def do_predict(self, *args, **kwargs):
        """Run ``curr_model`` on ``kwargs['input']`` and send the result.

        The image is resized to 224x224, converted to float64, given a leading
        batch axis and passed through ``preprocess_input``.
        """
        # TODO: Make this configurable
        # NOTE(review): the default is a 4-D zero array; confirm cv2.resize
        # accepts it when no 'input' is sent.
        image = kwargs.pop('input', np.zeros((1, 224, 224, 3)))

        resized = np.float64(cv2.resize(image, (224, 224)))
        preprocessed = preprocess_input(np.expand_dims(resized, axis=0))

        result = self.curr_model.predict(preprocessed, verbose=0)
        self.socket.send_pyobj({
            'type': 'prediction',
            'result': result
        }, zmq.NOBLOCK)

    def do_layerinfo(self, *args, **kwargs):
        """Send the output shape of ``curr_model`` and the selected layer's name."""
        self.socket.send_pyobj(
            {
                'type': 'layer_info',
                'shape': self.curr_model.compute_output_shape(
                    (1, 224, 224, 3)),
                # Same list change_layer indexes, so the name matches the
                # layer curr_model actually outputs.
                'name': self._selectable_layers()[self.layer].name
            }, zmq.NOBLOCK)

    def do_summary(self, *args, **kwargs):
        """Send the names of the selectable layers."""
        self.socket.send_pyobj(
            {
                'type': 'summary',
                'result': [layer.name for layer in self._selectable_layers()]
            }, zmq.NOBLOCK)

    def handle_message(self, message):
        """Dispatch ``message`` to its handler by ``message['type']``.

        Unknown types are ignored.
        """
        handlers = {
            'change_layer': self.change_layer,
            'predict': self.do_predict,
            'layer_info': self.do_layerinfo,
            'summary': self.do_summary,
        }
        handler = handlers.get(message['type'])
        if handler is not None:
            handler(**message)

    def run(self):
        """Receive and handle messages forever (blocking receive)."""
        self.running = True
        while self.running:
            message = self.socket.recv_pyobj()
            self.handle_message(message)
Пример #5
0
class ZmqReceiver(object):
    """Poll an optional REP socket and any number of SUB sockets in one loop.

    REP requests are answered with ``handle_incoming_message``'s JSON result.
    SUB messages are passed to the same handler, and its result is discarded.
    """

    def __init__(self, zmq_rep_bind_address=None, zmq_sub_connect_addresses=None, recreate_sockets_on_timeout_of_sec=600, username=None, password=None):
        self.context = zmq.Context()
        self.auth = None
        self.last_received_message = None
        self.is_running = False
        self.thread = None
        self.zmq_rep_bind_address = zmq_rep_bind_address
        self.zmq_sub_connect_addresses = zmq_sub_connect_addresses
        self.poller = zmq.Poller()
        self.sub_sockets = []
        self.rep_socket = None
        if username is not None and password is not None:
            # Start an authenticator for this context.
            # Does not work on PUB/SUB as far as I know (probably because the
            # more secure solutions require two-way communication as well).
            self.auth = ThreadAuthenticator(self.context)
            self.auth.start()
            # Instruct authenticator to handle PLAIN requests
            self.auth.configure_plain(domain='*', passwords={username: password})

        if self.zmq_sub_connect_addresses:
            for address in self.zmq_sub_connect_addresses:
                self.sub_sockets.append(SubSocket(self.context, self.poller, address, recreate_sockets_on_timeout_of_sec))
        if zmq_rep_bind_address:
            self.rep_socket = RepSocket(self.context, self.poller, zmq_rep_bind_address, self.auth)

    def stop(self):
        """Ask ``run`` to exit and stop the authenticator.

        ``run`` notices the flag after its current poll, which waits at most
        about 1 second (poll timeout is 1000 ms); it then destroys the sockets.
        """
        self.is_running = False
        logger.info("Closing pub and sub sockets...")
        if self.auth is not None:
            self.auth.stop()

    def run(self):
        """Poll all sockets until ``stop`` is called, then destroy them."""
        self.is_running = True

        while self.is_running:
            socks = dict(self.poller.poll(1000))
            logger.debug("Poll cycle over. checking sockets")
            if self.rep_socket:
                incoming_message = self.rep_socket.recv_string(socks)
                if incoming_message is not None:
                    self.last_received_message = incoming_message
                    try:
                        logger.debug("Got info from REP socket")
                        response_message = self.handle_incoming_message(incoming_message)
                        self.rep_socket.send(response_message)
                    except Exception as e:
                        # logger.exception keeps the traceback that
                        # logger.error(e) used to drop.
                        logger.exception(e)
            for sub_socket in self.sub_sockets:
                incoming_message = sub_socket.recv_string(socks)
                if incoming_message is not None:
                    # Heartbeats prove liveness but are not recorded as data.
                    if incoming_message != "zmq_sub_heartbeat":
                        self.last_received_message = incoming_message
                    logger.debug("Got info from SUB socket")
                    try:
                        self.handle_incoming_message(incoming_message)
                    except Exception as e:
                        logger.exception(e)

        if self.rep_socket:
            self.rep_socket.destroy()
        for sub_socket in self.sub_sockets:
            sub_socket.destroy()

    def create_response_message(self, status_code, status_message, response_message):
        """Serialize a status reply to JSON.

        ``response_message`` is included only when it is not None.
        """
        response = {"status_code": status_code, "status_message": status_message}
        if response_message is not None:
            response["response_message"] = response_message
        return json.dumps(response)

    def handle_incoming_message(self, message):
        """Return a 200/OK JSON reply, or None for the SUB heartbeat message."""
        if message != "zmq_sub_heartbeat":
            return self.create_response_message(200, "OK", None)
Пример #6
0
class Driver(drivers.BaseDriver):
    def __init__(
        self,
        args,
        encrypted_traffic_data=None,
        interface=None,
    ):
        """Initialize the Driver.

        When ``args.zmq_generate_keys`` is True, new CURVE certificates are
        generated and the process exits instead of building a driver.

        :param args: Arguments parsed by argparse. Attributes read here:
                     ``zmq_generate_keys``, ``mode``, ``zmq_server_address``
                     (client mode), ``zmq_bind_address`` (server mode) and
                     ``zmq_highwater_mark`` (defaults to 1024).
        :type args: Object
        :param encrypted_traffic_data: Encrypted-traffic settings. The keys
                                       ``enabled``, ``secret_keys_dir`` and
                                       ``public_keys_dir`` are read; a falsy
                                       value disables encryption.
        :type encrypted_traffic_data: Dictionary
        :param interface: The interface instance (client/server)
        :type interface: Object
        :raises SystemExit: with status 0 after generating certificates.
        """

        # Process / Semaphore are stored as classes (not instances);
        # presumably instantiated later by the base driver - confirm.
        self.thread_processor = multiprocessing.Process
        self.event = multiprocessing.Event()
        self.semaphore = multiprocessing.Semaphore
        self.flushqueue = _FlushQueue
        self.args = args
        if getattr(self.args, "zmq_generate_keys", False) is True:
            # Key-generation mode: write new certificates and exit.
            self._generate_certificates()
            print("New certificates generated")
            raise SystemExit(0)

        self.encrypted_traffic_data = encrypted_traffic_data

        # Clients connect to the server address, servers bind their own
        # address; any other mode binds on all interfaces ("*").
        mode = getattr(self.args, "mode", None)
        if mode == "client":
            self.bind_address = self.args.zmq_server_address
        elif mode == "server":
            self.bind_address = self.args.zmq_bind_address
        else:
            self.bind_address = "*"
        self.proto = "tcp"
        self.connection_string = "{proto}://{addr}".format(
            proto=self.proto, addr=self.bind_address)

        if self.encrypted_traffic_data:
            self.encrypted_traffic = self.encrypted_traffic_data.get("enabled")
            self.secret_keys_dir = self.encrypted_traffic_data.get(
                "secret_keys_dir")
            self.public_keys_dir = self.encrypted_traffic_data.get(
                "public_keys_dir")
        else:
            self.encrypted_traffic = False
            self.secret_keys_dir = None
            self.public_keys_dir = None

        # NOTE(review): pyzmq's Context.instance() is a classmethod returning
        # the process-wide singleton, so self.ctx is NOT self._context;
        # confirm which context the sockets are meant to use.
        self._context = zmq.Context()
        self.ctx = self._context.instance()
        self.poller = zmq.Poller()
        self.interface = interface
        # Base-class init runs after the attributes above are set; keep
        # this ordering in case the base class reads them.
        super(Driver, self).__init__(
            args=args,
            encrypted_traffic_data=self.encrypted_traffic_data,
            interface=interface,
        )
        self.bind_job = None
        self.bind_backend = None
        self.hwm = getattr(self.args, "zmq_highwater_mark", 1024)

    def __copy__(self):
        """Build a fresh Driver from this driver's construction arguments.

        Only ``args``, ``encrypted_traffic_data`` and ``interface`` carry
        over; all other state is re-created by ``__init__``.
        """
        init_kwargs = dict(
            args=self.args,
            encrypted_traffic_data=self.encrypted_traffic_data,
            interface=self.interface,
        )
        return Driver(**init_kwargs)

    def _backend_bind(self):
        """Bind a ROUTER socket to the backend port and return the socket.

        The socket's high-water mark is set to ``self.hwm`` before returning.

        :returns: Object
        """

        bind = self._socket_bind(
            socket_type=zmq.ROUTER,
            connection=self.connection_string,
            port=self.args.backend_port,
        )
        bind.set_hwm(self.hwm)
        # This is the bind side. The message previously said "connect"
        # (copied from _backend_connect), making both ends look alike in logs.
        self.log.debug(
            "Identity [ %s ] backend bind hwm state [ %s ]",
            self.identity,
            bind.get_hwm(),
        )
        return bind

    def _backend_connect(self):
        """Open a DEALER socket connected to the backend port.

        The high-water mark is set to ``self.hwm`` and logged before the
        socket is handed back.

        :returns: Object
        """
        self.log.debug("Establishing backend connection.")
        dealer = self._socket_connect(
            socket_type=zmq.DEALER,
            connection=self.connection_string,
            port=self.args.backend_port,
        )
        dealer.set_hwm(self.hwm)
        self.log.debug(
            "Identity [ %s ] backend connect hwm state [ %s ]",
            self.identity,
            dealer.get_hwm(),
        )
        return dealer

    def _bind_check(self, bind, interval=1, constant=1000):
        """Return True if a bind type contains work ready.

        Polls ``self.poller`` once for ``interval * constant`` milliseconds
        and reports whether ``bind`` became readable.

        :param bind: A given Socket bind to identify.
        :type bind: Object
        :param interval: Exponential Interval used to determine the polling
                         duration for a given socket.
        :type interval: Integer
        :param constant: Constant time used to poll for new jobs.
        :type constant: Integer
        :returns: Boolean
        """

        socks = dict(self.poller.poll(interval * constant))
        # Test the POLLIN bit rather than comparing for equality. A socket
        # registered for POLLIN|POLLOUT reports a combined mask, which the
        # former "== zmq.POLLIN" check treated as "no work ready".
        return bool(socks.get(bind, 0) & zmq.POLLIN)

    def _close(self, socket):
        """Close ``socket`` and wait until it reports itself closed.

        Does nothing when ``socket`` is None. Any exception raised while
        closing, including the TimeoutError below, is caught and logged
        rather than propagated.

        :param socket: Socket to close, or None.
        :type socket: Object
        """
        if socket is None:
            return

        try:
            # linger=2: pending outgoing messages are kept at most 2 ms
            # (ZMQ_LINGER is in milliseconds).
            socket.close(linger=2)
            close_time = time.time()
            # Retry once per second until closed, giving up after 60 seconds.
            while not socket.closed:
                if time.time() - close_time > 60:
                    # NOTE(review): relies on self.job_id being set; if it is
                    # not, the resulting AttributeError is what gets logged.
                    raise TimeoutError(
                        "Job [ {} ] failed to close transfer socket".format(
                            self.job_id))
                else:
                    socket.close(linger=2)
                    time.sleep(1)
        except Exception as e:
            self.log.error(
                "Ran into an exception while closing the socket %s",
                str(e),
            )
        else:
            # Logged for every socket closed here, not only backend ones.
            self.log.debug("Backend socket closed")

    def _generate_certificates(self, base_dir="/etc/directord"):
        """Create fresh server and client CURVE certificates under ``base_dir``.

        Steps:

        1. Existing ``*.key`` files in ``public_keys`` and ``*.key_secret``
           files in ``private_keys`` are renamed in place with ``.bak``
           appended.
        2. New ``server`` and ``client`` certificates are written to
           ``certificates``.
        3. The new ``*.key`` files are moved to ``public_keys`` and the
           ``*.key_secret`` files to ``private_keys``.

        :param base_dir: Directord configuration path.
        :type base_dir: String
        """
        staging_dir = os.path.join(base_dir, "certificates")
        public_dir = os.path.join(base_dir, "public_keys")
        private_dir = os.path.join(base_dir, "private_keys")

        for path in (staging_dir, public_dir, private_dir):
            os.makedirs(path, exist_ok=True)

        # Back up the currently installed keys.
        self._move_certificates(directory=public_dir, backup=True)
        self._move_certificates(directory=private_dir,
                                backup=True,
                                suffix=".key_secret")

        # Write the new certificates into the staging directory.
        for key_type in ("server", "client"):
            self._key_generate(keys_dir=staging_dir, key_type=key_type)

        # Install them: public keys first, then secret keys.
        for target, suffix in ((public_dir, ".key"),
                               (private_dir, ".key_secret")):
            self._move_certificates(
                directory=staging_dir,
                target_directory=target,
                suffix=suffix,
            )

    def _job_bind(self):
        """Return a ROUTER socket bound to the job port.

        :returns: Object
        """
        job_port = self.args.job_port
        router = self._socket_bind(
            socket_type=zmq.ROUTER,
            connection=self.connection_string,
            port=job_port,
        )
        return router

    def _job_connect(self):
        """Return a DEALER socket connected to the job port.

        :returns: Object
        """
        self.log.debug("Establishing Job connection.")
        job_port = self.args.job_port
        dealer = self._socket_connect(
            socket_type=zmq.DEALER,
            connection=self.connection_string,
            port=job_port,
        )
        return dealer

    def _key_generate(self, keys_dir, key_type):
        """Write a CURVE key pair named ``key_type`` into ``keys_dir``.

        Delegates to ``zmq.auth.create_certificates``; the file paths it
        returns are not used here.

        :param keys_dir: Full Directory path where a given key will be stored.
        :type keys_dir: String
        :param key_type: Key type to be generated (used as the file name).
        :type key_type: String
        """

        destination, cert_name = keys_dir, key_type
        zmq_auth.create_certificates(destination, cert_name)

    @staticmethod
    def _move_certificates(directory,
                           target_directory=None,
                           backup=False,
                           suffix=".key"):
        """Move or rename every entry in ``directory`` ending in ``suffix``.

        Entries keep their name unless ``backup`` is set, in which case
        ``.bak`` is appended. Without ``target_directory`` the entries are
        renamed in place inside ``directory``. An existing file at the
        destination is overwritten.

        :param directory: Set the origin path.
        :type directory: String
        :param target_directory: Set the target path.
        :type target_directory: String
        :param backup: Enable file backup before moving.
        :type backup:  Boolean
        :param suffix: Set the search suffix
        :type suffix: String
        """

        destination = target_directory or directory
        for item in os.listdir(directory):
            if not item.endswith(suffix):
                continue

            # listdir() yields bare names, so no basename() is needed.
            target_file = "{}.bak".format(item) if backup else item
            # os.replace overwrites an existing target on every platform;
            # os.rename raises on Windows when e.g. an old ".bak" exists.
            os.replace(
                os.path.join(directory, item),
                os.path.join(destination, target_file),
            )

    def _socket_bind(self, socket_type, connection, port, poller_type=None):
        """Return a socket object which has been bound to a given address.

        Unless the socket_type is PUB, the bound socket is also registered
        with self.poller as defined within the Interface class.

        When ``zmq_shared_key`` is set the socket becomes a PLAIN server;
        otherwise, when ``zmq_curve_encryption`` is set, it becomes a CURVE
        server using ``server.key_secret``. A single ThreadAuthenticator is
        started for the context and reused by later binds.

        :param socket_type: Set the Socket type, typically defined using a
                            ZeroMQ constant.
        :type socket_type: Integer
        :param connection: Set the Address information used for the bound
                           socket.
        :type connection: String
        :param port: Define the port which the socket will be bound to.
        :type port: Integer
        :param poller_type: Poll event mask used when registering the socket
                            with self.poller; defaults to ``zmq.POLLIN``.
        :type poller_type: Integer
        :returns: Object
        :raises SystemExit: When CURVE is enabled and the key directories or
                            the server secret key are missing or unreadable.
        """

        if poller_type is None:
            poller_type = zmq.POLLIN

        bind = self._socket_context(socket_type=socket_type)
        auth_enabled = (self.args.zmq_shared_key
                        or self.args.zmq_curve_encryption)

        if auth_enabled:
            # ZAP requests of a context are served on one fixed inproc
            # endpoint, so only one authenticator per context can handle
            # them. Start it once and reuse it instead of spawning (and
            # orphaning) a new authenticator thread for every bound socket.
            auth = getattr(self, "auth", None)
            if auth is None or not auth.is_alive():
                self.auth = ThreadAuthenticator(self.ctx, log=self.log)
                self.auth.start()
                self.auth.allow()

            if self.args.zmq_shared_key:
                # Enables basic auth
                self.auth.configure_plain(
                    domain="*", passwords={"admin": self.args.zmq_shared_key})
                bind.plain_server = True  # Enable shared key authentication
                self.log.info("Shared key authentication enabled.")
            elif self.args.zmq_curve_encryption:
                server_secret_file = os.path.join(self.secret_keys_dir,
                                                  "server.key_secret")
                for item in [
                        self.public_keys_dir,
                        self.secret_keys_dir,
                        server_secret_file,
                ]:
                    if not os.path.exists(item):
                        raise SystemExit(
                            "The required path [ {} ] does not exist. Have"
                            " you generated your keys?".format(item))
                # Accept clients whose public keys live in public_keys_dir.
                self.auth.configure_curve(domain="*",
                                          location=self.public_keys_dir)
                try:
                    server_public, server_secret = zmq_auth.load_certificate(
                        server_secret_file)
                except OSError as e:
                    self.log.error(
                        "Failed to load certificates: %s, Configuration: %s",
                        str(e),
                        vars(self.args),
                    )
                    raise SystemExit("Failed to load certificates")
                else:
                    bind.curve_secretkey = server_secret
                    bind.curve_publickey = server_public
                    bind.curve_server = True  # Enable curve authentication
        bind.bind("{connection}:{port}".format(
            connection=connection,
            port=port,
        ))

        if socket_type not in [zmq.PUB]:
            self.poller.register(bind, poller_type)

        return bind

    def _socket_connect(self, socket_type, connection, port, poller_type=None):
        """Create a socket, connect it to ``connection:port`` and return it.

        Credentials are applied before connecting. When ``zmq_shared_key`` is
        set, PLAIN is used with the hard-coded user ``admin``. Otherwise, when
        ``zmq_curve_encryption`` is set, CURVE is used with the client key
        pair and the server public key. SUB sockets subscribe to
        ``self.identity``; every other socket type uses it as its ZMQ
        identity. The socket is registered with self.poller before it
        connects.

        :param socket_type: Set the Socket type, typically defined using a
                            ZeroMQ constant.
        :type socket_type: Integer
        :param connection: Set the Address information the socket connects
                           to.
        :type connection: String
        :param port: Define the port which the socket will connect to.
        :type port: Integer
        :param poller_type: Poll event mask used when registering the socket
                            with self.poller; defaults to ``zmq.POLLIN``.
        :type poller_type: Integer
        :returns: Object
        :raises SystemExit: When CURVE is enabled and key material is missing
                            or cannot be loaded.
        """

        poll_mask = zmq.POLLIN if poller_type is None else poller_type
        sock = self._socket_context(socket_type=socket_type)

        if self.args.zmq_shared_key:
            sock.plain_username = b"admin"  # User is hard coded.
            sock.plain_password = self.args.zmq_shared_key.encode()
            self.log.info("Shared key authentication enabled.")
        elif self.args.zmq_curve_encryption:
            client_secret_file = os.path.join(self.secret_keys_dir,
                                              "client.key_secret")
            server_public_file = os.path.join(self.public_keys_dir,
                                              "server.key")
            required_paths = (
                self.public_keys_dir,
                self.secret_keys_dir,
                client_secret_file,
                server_public_file,
            )
            missing = next(
                (path for path in required_paths if not os.path.exists(path)),
                None,
            )
            if missing is not None:
                raise SystemExit(
                    "The required path [ {} ] does not exist. Have"
                    " you generated your keys?".format(missing))

            try:
                client_public, client_secret = zmq_auth.load_certificate(
                    client_secret_file)
                server_public, _ = zmq_auth.load_certificate(
                    server_public_file)
            except OSError as e:
                self.log.error(
                    "Error while loading certificates: %s. Configuration: %s",
                    str(e),
                    vars(self.args),
                )
                raise SystemExit("Failed to load keys.")

            sock.curve_secretkey = client_secret
            sock.curve_publickey = client_public
            sock.curve_serverkey = server_public

        identity_option = (
            zmq.SUBSCRIBE if socket_type == zmq.SUB else zmq.IDENTITY
        )
        sock.setsockopt_string(identity_option, self.identity)

        self.poller.register(sock, poll_mask)
        sock.connect(
            "{connection}:{port}".format(connection=connection, port=port)
        )

        self.log.info("Socket connected to [ %s ].", connection)
        return sock

    def _socket_context(self, socket_type):
        """Create a socket from self.ctx with the driver's default options.

        Linger comes from ``args.heartbeat_interval`` (60 when absent). Both
        high-water marks are set to ``4 * self.hwm``. ROUTER sockets are
        created with ``ROUTER_MANDATORY`` enabled.

        :param socket_type: Set the Socket type, typically defined using a
                            ZeroMQ constant.
        :type socket_type: Integer
        :returns: Object
        """

        bind = self.ctx.socket(socket_type)
        # NOTE(review): ZMQ_LINGER is measured in milliseconds; confirm that
        # heartbeat_interval is intended to be used as a millisecond value.
        bind.linger = getattr(self.args, "heartbeat_interval", 60)
        # set_hwm() sets both SNDHWM and RCVHWM (HWM on libzmq < 3), so no
        # separate sndhwm/rcvhwm/setsockopt calls are needed.
        bind.set_hwm(int(self.hwm * 4))
        if socket_type == zmq.ROUTER:
            # Fail with EHOSTUNREACH instead of silently dropping messages
            # addressed to an unknown peer identity.
            bind.setsockopt(zmq.ROUTER_MANDATORY, 1)

        return bind

    @staticmethod
    def _socket_recv(socket, nonblocking=False):
        """Receive one multipart message from a ZeroMQ socket.

        Messages arriving at the server are framed as follows.

            [
                b"Identity",
                b"ID",
                b"ASCII Control Characters",
                b"command",
                b"data",
                b"info",
                b"stderr",
                b"stdout",
            ]

        Messages arriving at a client omit the leading identity frame.

            [
                b"ID",
                b"ASCII Control Characters",
                b"command",
                b"data",
                b"info",
                b"stderr",
                b"stdout",
            ]

        All message parts are byte encoded.

        The control characters in use are defined within the Interface
        class; background on ASCII control characters is available at
        URL(https://donsnotes.com/tech/charsets/ascii.html#cntrl).

        :param socket: ZeroMQ socket object.
        :type socket: Object
        :param nonblocking: Pass ``zmq.NOBLOCK`` so the call does not wait
                            for a message.
        :type nonblocking: Boolean
        :returns: List of byte frames as returned by ``recv_multipart``.
        """

        receive_flags = zmq.NOBLOCK if nonblocking else 0
        return socket.recv_multipart(flags=receive_flags)

    @tenacity.retry(
        retry=tenacity.retry_if_exception_type(Exception),
        wait=tenacity.wait_fixed(5),
        before_sleep=tenacity.before_sleep_log(
            logger.getLogger(name="directord"), logging.WARN),
    )
    def _socket_send(
        self,
        socket,
        identity=None,
        msg_id=None,
        control=None,
        command=None,
        data=None,
        info=None,
        stderr=None,
        stdout=None,
        nonblocking=False,
    ):
        """Send a message over a ZMQ socket.

        The message specification for server is as follows.

            [
                b"Identity",
                b"ID",
                b"ASCII Control Characters",
                b"command",
                b"data",
                b"info",
                b"stderr",
                b"stdout",
            ]

        The message specification for client is as follows.

            [
                b"ID",
                b"ASCII Control Characters",
                b"command",
                b"data",
                b"info",
                b"stderr",
                b"stdout",
            ]

        ``str`` parts are encoded to bytes before sending; any falsy part is
        replaced with ``self.nullbyte`` (``msg_id`` with a new UUID).

        All possible control characters are defined within the Interface class.
        For more on control characters review the following
        URL(https://donsnotes.com/tech/charsets/ascii.html#cntrl).

        NOTE(review): the retry policy has no stop condition, so a send that
        keeps failing is retried every 5 seconds indefinitely.

        :param socket: ZeroMQ socket object.
        :type socket: Object
        :param identity: Target where message will be sent.
        :type identity: Bytes
        :param msg_id: ID information for a given message. If no ID is
                       provided a UUID will be generated.
        :type msg_id: Bytes
        :param control: ASCII control characters.
        :type control: Bytes
        :param command: Command definition for a given message.
        :type command: Bytes
        :param data: Encoded data that will be transmitted.
        :type data: Bytes
        :param info: Encoded information that will be transmitted.
        :type info: Bytes
        :param stderr: Encoded error information from a command.
        :type stderr: Bytes
        :param stdout: Encoded output information from a command.
        :type stdout: Bytes
        :param nonblocking: Enable non-blocking send.
        :type nonblocking: Boolean
        :returns: Object
        """
        def _encoder(item):
            """Encode str parts to bytes; pass anything else through."""
            try:
                return item.encode()
            except AttributeError:
                return item

        # Every frame must be present, so falsy parts get a placeholder.
        msg_id = msg_id or utils.get_uuid()
        control = control or self.nullbyte
        command = command or self.nullbyte
        data = data or self.nullbyte
        info = info or self.nullbyte
        stderr = stderr or self.nullbyte
        stdout = stdout or self.nullbyte

        message_parts = [msg_id, control, command, data, info, stderr, stdout]

        if identity:
            # Addressed sends lead with the peer identity frame.
            message_parts.insert(0, identity)

        message_parts = [_encoder(i) for i in message_parts]

        flags = zmq.NOBLOCK if nonblocking else 0

        try:
            return socket.send_multipart(message_parts, flags=flags)
        except Exception:
            # Logger.warn() is a deprecated alias; use warning().
            self.log.warning("Failed to send message to [ %s ]", identity)
            raise

    def _recv(self, socket, nonblocking=False):
        """Receive one message and decode every frame to ``str``.

        :param socket: ZeroMQ socket object.
        :type socket: Object
        :param nonblocking: Do not wait when no message is queued.
        :type nonblocking: Boolean
        :returns: Tuple
        """

        frames = self._socket_recv(socket=socket, nonblocking=nonblocking)
        return tuple(frame.decode() for frame in frames)

    def backend_recv(self, nonblocking=False):
        """Receive one message from the backend socket.

        :param nonblocking: Do not wait when no message is queued.
        :type nonblocking: Boolean
        :returns: Tuple of decoded message frames.
        """

        backend_socket = self.bind_backend
        return self._recv(socket=backend_socket, nonblocking=nonblocking)

    def backend_init(self):
        """Create the backend socket and store it on ``self.bind_backend``.

        Server mode binds a local socket; any other mode connects to the
        server. Nothing is returned.
        """

        in_server_mode = self.args.mode == "server"
        self.bind_backend = (
            self._backend_bind() if in_server_mode
            else self._backend_connect()
        )

    def backend_close(self):
        """Close the backend socket through ``self._close``."""

        backend_socket = self.bind_backend
        self._close(socket=backend_socket)

    def backend_check(self, interval=1, constant=1000):
        """Report whether the backend socket has work ready.

        Thin wrapper around ``self._bind_check`` for ``self.bind_backend``.

        :param interval: Exponential Interval used to determine the polling
                         duration for a given socket.
        :type interval: Integer
        :param constant: Constant time used to poll for new jobs.
        :type constant: Integer
        :returns: Object
        """

        poll_settings = {"interval": interval, "constant": constant}
        return self._bind_check(bind=self.bind_backend, **poll_settings)

    def backend_send(self, *args, **kwargs):
        """Send a message on the backend socket.

        All args and kwargs are forwarded to ``self._socket_send``; any
        ``socket`` keyword is replaced by ``self.bind_backend``.

        :returns: Object
        """

        forwarded = dict(kwargs, socket=self.bind_backend)
        return self._socket_send(*args, **forwarded)

    @staticmethod
    def get_lock():
        """Returns a thread lock."""

        return multiprocessing.Lock()

    def heartbeat_send(self,
                       host_uptime=None,
                       agent_uptime=None,
                       version=None,
                       driver=None):
        """Send a heartbeat message on the job socket.

        The payload is a JSON object carrying a fresh job id, the supplied
        uptime/version/driver values and ``self.machine_id``.

        :param host_uptime: Sender uptime
        :type host_uptime: String
        :param agent_uptime: Sender agent uptime
        :type agent_uptime: String
        :param version: Sender directord version
        :type version: String
        :param driver: Driver information
        :type driver: String
        :returns: Object
        """

        heartbeat_id = utils.get_uuid()
        self.log.info(
            "Job [ %s ] sending heartbeat from [ %s ] to server",
            heartbeat_id,
            self.identity,
        )

        payload = {
            "job_id": heartbeat_id,
            "version": version,
            "host_uptime": host_uptime,
            "agent_uptime": agent_uptime,
            "machine_id": self.machine_id,
            "driver": driver,
        }
        return self.job_send(
            control=self.heartbeat_notice,
            msg_id=heartbeat_id,
            data=json.dumps(payload),
        )

    def job_send(self, *args, **kwargs):
        """Send a message on the job socket.

        All args and kwargs are forwarded to ``self._socket_send``; any
        ``socket`` keyword is replaced by ``self.bind_job``.

        :returns: Object
        """

        forwarded = dict(kwargs, socket=self.bind_job)
        return self._socket_send(*args, **forwarded)

    def job_recv(self, nonblocking=False):
        """Receive one message from the job socket.

        :param nonblocking: Do not wait when no message is queued.
        :type nonblocking: Boolean
        :returns: Tuple of decoded message frames.
        """

        job_socket = self.bind_job
        return self._recv(socket=job_socket, nonblocking=nonblocking)

    def job_init(self):
        """Create the job socket and store it on ``self.bind_job``.

        Server mode binds a local socket; any other mode connects to the
        server. Nothing is returned.
        """

        if self.args.mode != "server":
            self.bind_job = self._job_connect()
            return

        self.bind_job = self._job_bind()

    def job_close(self):
        """Close the job socket through ``self._close``."""

        job_socket = self.bind_job
        self._close(socket=job_socket)

    def job_check(self, interval=1, constant=1000):
        """Report whether the job socket has work ready.

        Thin wrapper around ``self._bind_check`` for ``self.bind_job``.

        :param interval: Exponential Interval used to determine the polling
                         duration for a given socket.
        :type interval: Integer
        :param constant: Constant time used to poll for new jobs.
        :type constant: Integer
        :returns: Object
        """

        poll_settings = {"interval": interval, "constant": constant}
        return self._bind_check(bind=self.bind_job, **poll_settings)

    def shutdown(self):
        """Shutdown the driver.

        The job and backend sockets are closed first and a running
        authenticator is stopped. Only then are ``self.ctx`` and
        ``self._context`` closed, when they expose a ``close`` method.
        ZeroMQ requires a context's sockets to be closed before the context
        itself is shut down.
        """

        self.job_close()
        self.backend_close()

        auth = getattr(self, "auth", None)
        if auth is not None:
            # The ThreadAuthenticator owns a background thread and a ZAP
            # socket on the context; stop it before the context goes away.
            auth.stop()

        # NOTE(review): pyzmq's Context exposes term()/destroy(), not
        # close(); confirm which objects these hasattr() checks target.
        if hasattr(self.ctx, "close"):
            self.ctx.close()

        if hasattr(self._context, "close"):
            self._context.close()
Пример #7
0
class ZmqReceiver(ZmqBase):
    '''
    A ZmqReceiver class will listen on a REP or SUB socket for messages
    and will call the 'handle_incoming_message()' method to process it.
    Subclasses should override that. A response must be implemented for
    REP sockets, but is useless for SUB sockets.
    '''
    def __init__(self,
                 zmq_rep_bind_address: Optional[str] = None,
                 zmq_sub_connect_addresses: Optional[
                     Tuple[SubSocketAddress, ...]] = None,
                 recreate_timeout: Optional[int] = 600,
                 username: Optional[str] = None,
                 password: Optional[str] = None):
        '''
        :param zmq_rep_bind_address: Address the REP socket binds to; no REP
            socket is created when it is falsy.
        :param zmq_sub_connect_addresses: One SubSocket is created per
            address (none when omitted).
        :param recreate_timeout: Passed to every SubSocket as
            ``timeout_in_sec``.
        :param username: PLAIN username. Authentication is only set up when
            both ``username`` and ``password`` are given.
        :param password: PLAIN password.
        '''
        super().__init__()
        self.__context = zmq.Context()
        self.__poller = zmq.Poller()

        self.__sub_sockets = tuple(
            SubSocket(
                ctx=self.__context,
                poller=self.__poller,
                address=address,
                timeout_in_sec=recreate_timeout,
            ) for address in (zmq_sub_connect_addresses or tuple()))

        self.__auth: Optional[ThreadAuthenticator] = None
        if username is not None and password is not None:
            # Start an authenticator for this context.
            # Does not work on PUB/SUB as far as I know (probably because the
            # more secure solutions require two way communication as well)
            self.__auth = ThreadAuthenticator(self.__context)
            self.__auth.start()

            # Instruct authenticator to handle PLAIN requests
            self.__auth.configure_plain(domain='*',
                                        passwords={
                                            username: password,
                                        })

        self.__rep_socket = RepSocket(
            ctx=self.__context,
            poller=self.__poller,
            address=zmq_rep_bind_address,
            auth=self.__auth,
        ) if zmq_rep_bind_address else None

        self.__last_received_message = None
        self.__is_running = False

    @property
    def is_running(self) -> bool:
        return self.__is_running

    def stop(self) -> None:
        '''
        Ask run() to exit and stop the authenticator, if any.

        May take up to about one second to actually stop, since run() polls
        with a 1000 ms timeout; run() destroys the sockets once it exits.
        '''

        self._info('Closing pub and sub sockets...')
        self.__is_running = False

        if self.__auth is not None:
            self.__auth.stop()

    def _run_rep_socket(self, socks) -> None:
        '''Handle a pending REP request, if any, and send the reply.'''
        if self.__rep_socket is None:
            return

        incoming_message = self.__rep_socket.recv_string(socks)
        if incoming_message is None:
            return

        if incoming_message != self.HEARTBEAT_MSG:
            self.__last_received_message = incoming_message

        self._debug('Got info from REP socket')

        try:
            response_message = self.handle_incoming_message(incoming_message)
            self.__rep_socket.send(response_message)
        except Exception as e:
            # NOTE(review): when the handler raises, no reply is sent. A REP
            # socket must reply before it can receive again, so confirm that
            # RepSocket recovers from this state.
            self._error(e)

    def _run_sub_sockets(self, socks) -> None:
        '''Pass every pending SUB message to handle_incoming_message().'''
        for sub_socket in self.__sub_sockets:
            incoming_message = sub_socket.recv_string(socks)

            if incoming_message is None:
                continue

            if incoming_message != self.HEARTBEAT_MSG:
                self.__last_received_message = incoming_message

            self._debug('Got info from SUB socket')

            try:
                self.handle_incoming_message(incoming_message)
            except Exception as e:
                self._error(e)

    def run(self) -> None:
        '''
        Poll all sockets until stop() is called, then destroy them.

        Poller errors are logged and polling continues. KeyboardInterrupt
        and SystemExit are not swallowed. The sockets are destroyed even
        when the loop exits with an exception.
        '''
        self.__is_running = True

        try:
            while self.__is_running:
                try:
                    socks = dict(self.__poller.poll(1000))
                except Exception as e:
                    self._error(e)
                    continue

                self._debug('Poll cycle over. checking sockets')
                self._run_rep_socket(socks)
                self._run_sub_sockets(socks)
        finally:
            self.__is_running = False

            if self.__rep_socket:
                self.__rep_socket.destroy()

            for sub_socket in self.__sub_sockets:
                sub_socket.destroy()

    def create_response_message(self,
                                status_code: int,
                                status_message: str,
                                response_message: Optional[str] = None) -> str:
        '''Return a JSON reply; the response key is omitted when None.'''
        payload = {
            self.STATUS_CODE: status_code,
            self.STATUS_MSG: status_message,
        }

        if response_message is not None:
            payload[self.RESPONSE_MSG] = response_message

        return json.dumps(payload)

    def handle_incoming_message(self, message: str) -> Optional[str]:
        '''Default handler: no reply for heartbeats, an OK reply otherwise.'''
        if message == self.HEARTBEAT_MSG:
            return None

        return self.create_response_message(
            status_code=self.STATUS_CODE_OK,
            status_message=self.STATUS_MSG_OK,
        )

    def get_last_received_message(self) -> Optional[str]:
        return self.__last_received_message

    def get_sub_socket(self, idx: int) -> SubSocket:
        return self.__sub_sockets[idx]
Пример #8
0
class ZmqReceiver(object):
    """Listen on an optional REP socket and any number of SUB sockets.

    Each incoming message is passed to handle_incoming_message(). For the
    REP socket, its return value is sent back as the reply. Subclasses
    should override handle_incoming_message().
    """

    # Heartbeat payload; it is never stored as last_received_message.
    HEARTBEAT_MSG = "zmq_sub_heartbeat"

    def __init__(self,
                 zmq_rep_bind_address=None,
                 zmq_sub_connect_addresses=None,
                 recreate_sockets_on_timeout_of_sec=600,
                 username=None,
                 password=None):
        self.context = zmq.Context()
        self.auth = None
        self.last_received_message = None
        self.is_running = False
        self.thread = None
        self.zmq_rep_bind_address = zmq_rep_bind_address
        self.zmq_sub_connect_addresses = zmq_sub_connect_addresses
        self.poller = zmq.Poller()
        self.sub_sockets = []
        self.rep_socket = None
        if username is not None and password is not None:
            # Start an authenticator for this context.
            # Does not work on PUB/SUB as far as I know (probably because the
            # more secure solutions require two way communication as well)
            self.auth = ThreadAuthenticator(self.context)
            self.auth.start()
            # Instruct authenticator to handle PLAIN requests
            self.auth.configure_plain(domain='*',
                                      passwords={username: password})

        if self.zmq_sub_connect_addresses:
            for address in self.zmq_sub_connect_addresses:
                self.sub_sockets.append(
                    SubSocket(self.context, self.poller, address,
                              recreate_sockets_on_timeout_of_sec))
        if zmq_rep_bind_address:
            self.rep_socket = RepSocket(self.context, self.poller,
                                        zmq_rep_bind_address, self.auth)

    # May take up to about one second to actually stop, since run() polls
    # with a 1000 ms timeout.
    def stop(self):
        self.is_running = False
        logger.info("Closing pub and sub sockets...")
        if self.auth is not None:
            self.auth.stop()

    def _handle_rep_socket(self, socks):
        """Answer a pending REP request, if any."""
        incoming_message = self.rep_socket.recv_string(socks)
        if incoming_message is None:
            return
        # Same heartbeat rule as the SUB sockets.
        if incoming_message != self.HEARTBEAT_MSG:
            self.last_received_message = incoming_message
        try:
            logger.debug("Got info from REP socket")
            response_message = self.handle_incoming_message(incoming_message)
            self.rep_socket.send(response_message)
        except Exception as e:
            logger.error(e)

    def _handle_sub_sockets(self, socks):
        """Dispatch every pending SUB message to handle_incoming_message()."""
        for sub_socket in self.sub_sockets:
            incoming_message = sub_socket.recv_string(socks)
            if incoming_message is None:
                continue
            if incoming_message != self.HEARTBEAT_MSG:
                self.last_received_message = incoming_message
            logger.debug("Got info from SUB socket")
            try:
                self.handle_incoming_message(incoming_message)
            except Exception as e:
                logger.error(e)

    def run(self):
        """Poll the sockets until stop() is called, then destroy them.

        The sockets are destroyed even when the loop exits with an
        exception (e.g. a poller error).
        """
        self.is_running = True

        try:
            while self.is_running:
                socks = dict(self.poller.poll(1000))
                logger.debug("Poll cycle over. checking sockets")
                if self.rep_socket:
                    self._handle_rep_socket(socks)
                self._handle_sub_sockets(socks)
        finally:
            self.is_running = False
            if self.rep_socket:
                self.rep_socket.destroy()
            for sub_socket in self.sub_sockets:
                sub_socket.destroy()

    def create_response_message(self, status_code, status_message,
                                response_message=None):
        """Return a JSON reply; "response_message" is omitted when None."""
        payload = {
            "status_code": status_code,
            "status_message": status_message,
        }
        if response_message is not None:
            payload["response_message"] = response_message
        return json.dumps(payload)

    def handle_incoming_message(self, message):
        """Default handler: None for heartbeats, a 200/OK reply otherwise."""
        if message == self.HEARTBEAT_MSG:
            return None
        return self.create_response_message(200, "OK", None)
Пример #9
0
class RpcServer:
    """Serve registered functions over REQ/REP and publish data over PUB.

    Requests received on the REP socket are unpickled ``(name, args,
    kwargs)`` tuples. Replies are ``[True, result]`` on success or
    ``[False, traceback_text]`` on failure. While running, a heartbeat is
    published on ``HEARTBEAT_TOPIC`` every ``HEARTBEAT_INTERVAL``.
    """
    def __init__(self) -> None:
        """
        Create the context and the REP/PUB sockets; nothing is bound yet.
        """
        # Save functions dict: key is function name, value is function object
        self._functions: Dict[str, Callable] = {}

        # Zmq port related
        self._context: zmq.Context = zmq.Context()

        # Reply socket (Request–reply pattern)
        self._socket_rep: zmq.Socket = self._context.socket(zmq.REP)

        # Publish socket (Publish–subscribe pattern)
        self._socket_pub: zmq.Socket = self._context.socket(zmq.PUB)

        # Worker thread related
        self._active: bool = False  # RpcServer status
        self._thread: threading.Thread = None  # RpcServer thread
        self._lock: threading.Lock = threading.Lock()

        # Heartbeat related: time() value at which the next one is due
        self._heartbeat_at: float = None

        # Authenticator used to ensure data security
        self.__authenticator: ThreadAuthenticator = None

    def is_active(self) -> bool:
        """Return True while the server is running."""
        return self._active

    def start(self,
              rep_address: str,
              pub_address: str,
              username: str = "",
              password: str = "",
              server_secretkey_path: str = "") -> None:
        """
        Start RpcServer: set up optional authentication, bind both sockets
        and launch the worker thread. Does nothing if already running.

        :param rep_address: endpoint the REP socket binds to.
        :param pub_address: endpoint the PUB socket binds to.
        :param username: PLAIN username (used only with a password and no
            server_secretkey_path).
        :param password: PLAIN password.
        :param server_secretkey_path: server CURVE secret key file; when set,
            CURVE is enabled and any client key is accepted
            (``CURVE_ALLOW_ANY``).
        """
        if self._active:
            return

        # Start authenticator. These must be the single-underscore
        # attributes created in __init__: "self.__context" etc. would be
        # name-mangled to attributes that do not exist.
        if server_secretkey_path:
            self.__authenticator = ThreadAuthenticator(self._context)
            self.__authenticator.start()
            self.__authenticator.configure_curve(
                domain="*", location=zmq.auth.CURVE_ALLOW_ANY)

            publickey, secretkey = zmq.auth.load_certificate(
                server_secretkey_path)

            for sock in (self._socket_pub, self._socket_rep):
                sock.curve_secretkey = secretkey
                sock.curve_publickey = publickey
                sock.curve_server = True
        elif username and password:
            self.__authenticator = ThreadAuthenticator(self._context)
            self.__authenticator.start()
            self.__authenticator.configure_plain(
                domain="*", passwords={username: password})

            for sock in (self._socket_pub, self._socket_rep):
                sock.plain_server = True

        # Bind socket address
        self._socket_rep.bind(rep_address)
        self._socket_pub.bind(pub_address)

        # Schedule the first heartbeat before the worker starts; otherwise
        # check_heartbeat() could compare time() against None.
        self._heartbeat_at = time() + HEARTBEAT_INTERVAL

        # Start RpcServer status
        self._active = True

        # Start RpcServer thread
        self._thread = threading.Thread(target=self.run)
        self._thread.start()

    def stop(self) -> None:
        """
        Stop RpcServer. Only clears the active flag; the worker exits after
        its current poll (at most about one second). Use join() to wait.
        """
        if not self._active:
            return

        # Stop RpcServer status
        self._active = False

    def join(self) -> None:
        """Wait for the RpcServer thread to exit."""
        if self._thread and self._thread.is_alive():
            self._thread.join()
        self._thread = None

    def run(self) -> None:
        """
        Worker loop: answer requests and publish heartbeats until stop().
        Unbinds both sockets and stops the authenticator on exit.
        """
        while self._active:
            # Poll response socket for 1 second
            n: int = self._socket_rep.poll(1000)
            self.check_heartbeat()

            if not n:
                continue

            # A REP socket must answer every request it receives, so any
            # failure (unpickling, malformed request, unknown function,
            # exception in the function) is reported back to the caller
            # instead of killing this thread.
            try:
                # NOTE(review): recv_pyobj() unpickles network input; only
                # expose the REP endpoint to trusted clients or enable auth.
                req = self._socket_rep.recv_pyobj()
                name, args, kwargs = req
                func: Callable = self._functions[name]
                r: Any = func(*args, **kwargs)
                rep: list = [True, r]
            except Exception:  # noqa
                rep = [False, traceback.format_exc()]

            try:
                self._socket_rep.send_pyobj(rep)
            except Exception:  # noqa
                # e.g. the result could not be pickled; the reply is still
                # owed, so send the error instead.
                self._socket_rep.send_pyobj([False, traceback.format_exc()])

        # Unbind socket address
        self._socket_pub.unbind(self._socket_pub.LAST_ENDPOINT)
        self._socket_rep.unbind(self._socket_rep.LAST_ENDPOINT)
        if self.__authenticator:
            self.__authenticator.stop()

    def publish(self, topic: str, data: Any) -> None:
        """
        Publish ``[topic, data]`` on the PUB socket (serialized by a lock).
        """
        with self._lock:
            self._socket_pub.send_pyobj([topic, data])

    def register(self, func: Callable) -> None:
        """
        Register function under its ``__name__`` (replacing any previous one).
        """
        self._functions[func.__name__] = func

    def check_heartbeat(self) -> None:
        """
        Publish a heartbeat when it is due and schedule the next one.
        """
        now: float = time()
        if now >= self._heartbeat_at:
            # Publish heartbeat
            self.publish(HEARTBEAT_TOPIC, now)

            # Update timestamp of next publish
            self._heartbeat_at = now + HEARTBEAT_INTERVAL
Пример #10
0
class CombaZMQAdapter(threading.Thread, CombaBase):
    
    def __init__(self, port):
        """
        Create the adapter and immediately start its server thread.

        :param port: TCP port that run() binds the REP socket to; kept as a string
        """
        self.port = str(port)
        threading.Thread.__init__(self)
        self.shutdown_event = Event()
        # Context.instance() is a classmethod returning the process-wide
        # singleton context. Calling it on a fresh zmq.Context(), as before,
        # created an extra context that was immediately thrown away.
        self.context = zmq.Context.instance()
        self.authserver = ThreadAuthenticator(self.context)
        self.loadConfig()
        # The thread starts as soon as the object is constructed
        self.start()

    #------------------------------------------------------------------------------------------#
    def run(self):
        """
        Thread entry point: serve client commands until halted.

        Starts the authentication server, binds a REP socket (PLAIN server)
        on ``tcp://*:<port>`` and creates the ``CombaController``. It then
        handles one request per iteration.

        A request is a space-separated string ``"<command> <p1> <p2> ..."``.
        ``<command>`` must be a plain public method name of the controller.
        That method is called with the remaining words as string arguments.
        Problems are reported to the client through ``self.controller.message``.
        """
        import re  # local import: only needed to validate command names

        valid_command = re.compile(r'[A-Za-z_][A-Za-z0-9_]*\Z')

        self.startAuthserver()
        self.data = ''
        self.socket = self.context.socket(zmq.REP)
        self.socket.plain_server = True
        self.socket.bind("tcp://*:" + self.port)
        self.shutdown_event.clear()
        self.controller = CombaController(self, self.lqs_socket, self.lqs_recorder_socket)
        self.controller.messenger.setMailAddresses(self.get('frommail'), self.get('adminmail'))
        self.can_send = False
        # Process tasks forever
        while not self.shutdown_event.is_set():
            self.data = self.socket.recv()
            self.can_send = True
            data = self.data.split(' ')
            command = str(data.pop(0))

            # Security: command and parameters come from the network. The
            # previous code assembled a Python statement from them and exec'd
            # it, so a crafted parameter (e.g. one containing a quote) could
            # run arbitrary code. Now only identifier-shaped command names
            # are dispatched, and parameters are passed as plain strings.
            if not valid_command.match(command):
                self.controller.message('Warning: Syntax Error')
                continue

            try:
                if command.startswith('_'):
                    # Never expose private/dunder attributes to clients
                    raise AttributeError(command)
                method = getattr(self.controller, command)
                method(*data)
            except AttributeError:
                print("Warning: Method " + command + " does not exist")
                self.controller.message('Warning: Method ' + command + ' does not exist')
            except TypeError:
                print("Warning: Wrong number of params")
                self.controller.message('Warning: Wrong number of params')
            except Exception:
                print("Warning: Unknown Error")
                self.controller.message('Warning: Unknown Error')

        return

    #------------------------------------------------------------------------------------------#
    def halt(self):
        """
        Stop the server.

        Does nothing if the shutdown event is already set. Otherwise it drops
        the controller, sets the shutdown event and unbinds the REP socket
        from ``tcp://*:<port>``. Both cleanup steps are best effort.
        """
        if self.shutdown_event.is_set():
            return
        try:
            del self.controller
        except AttributeError:
            # No controller has been created yet (run() not reached)
            pass
        self.shutdown_event.set()
        try:
            self.socket.unbind("tcp://*:" + self.port)
        except Exception:
            # Best effort: e.g. run() has not created the socket yet
            pass
        # NOTE(review): the socket is unbound but never closed; a
        # self.socket.close() call was commented out here - confirm intent.

    #------------------------------------------------------------------------------------------#
    def reload(self):
        """
        Reload the configuration, then stop and restart the server.

        Ignored once the shutdown event is set. There is a fixed 3 second
        pause between stopping and restarting.

        NOTE(review): run() is called directly, so the restarted serve loop
        runs in the caller's thread. This call only returns when that loop
        ends - confirm this is intended.
        """
        if not self.shutdown_event.is_set():
            self.loadConfig()
            self.halt()
            time.sleep(3)
            self.run()

    #------------------------------------------------------------------------------------------#
    def send(self, message):
        """
        Reply to the client, but only if a reply is still pending.

        run() sets ``can_send`` after each received request, and it is
        cleared here, so at most one message goes out per request.
        :param message: string
        """
        if not self.can_send:
            return
        self.socket.send(message)
        self.can_send = False

    #------------------------------------------------------------------------------------------#
    def startAuthserver(self):
        """
        (Re)start the zmq authentication server according to ``self.securitylevel``.

        - Any running authenticator is stopped first.
        - securitylevel > 0: start the authenticator and configure PLAIN
          (username/password) authentication for all domains, using the
          accounts from getAccounts().
        - securitylevel > 1: additionally allow the IP addresses returned by
          ``CombaWhitelist().getList()``.
        """
        # stop auth server if running
        if self.authserver.is_alive():
            self.authserver.stop()
        if self.securitylevel > 0:
            # Start the authentication server
            self.authserver.start()

            # Security level 2 and above: also register the IP whitelist.
            # NOTE(review): the original German comment here said level 2
            # should "also require password and username", but PLAIN auth is
            # configured below for every level > 0 - confirm which is intended.
            if self.securitylevel > 1:
                try:
                    addresses = CombaWhitelist().getList()
                    for address in addresses:
                        self.authserver.allow(address)
                except Exception as e:
                    # Still best effort, but no longer silent.
                    # NOTE(review): with pyzmq's Authenticator an empty allow
                    # list means IPs are not filtered at all - confirm that
                    # failing open here is acceptable.
                    print("Warning: could not load IP whitelist: " + str(e))

            # Instruct authenticator to handle PLAIN requests
            self.authserver.configure_plain(domain='*', passwords=self.getAccounts())

    #------------------------------------------------------------------------------------------#
    def getAccounts(self):
        """
        Build the account table used for PLAIN authentication.

        Starts from ``CombaUser().getLogins()`` and adds the internal account
        stored in redis under ``internAccess`` as ``"user:password"``. If that
        key is missing or empty, a random internal account is generated and
        stored first.

        :return: mapping of username -> password (the logins from
                 CombaUser().getLogins() plus the internal account)
        """
        accounts = CombaUser().getLogins()
        db = redis.Redis()

        internaccount = db.get('internAccess')
        if not internaccount:
            # These are credentials, so use the OS CSPRNG instead of the
            # default Mersenne Twister. Use the ASCII alphabets because
            # string.lowercase/uppercase depend on the current locale.
            rng = random.SystemRandom()
            user = ''.join(rng.sample(string.ascii_lowercase, 10))
            password = ''.join(rng.sample(string.ascii_letters + string.digits, 22))
            db.set('internAccess', user + ':' + password)
            intern = [user, password]
        else:
            # Split on the first ':' only so a password containing ':' survives
            intern = internaccount.split(':', 1)

        accounts[intern[0]] = intern[1]

        return accounts