class MispZmq:
    """Forward MISP messages queued in Redis lists to a ZMQ PUB socket.

    Settings are read from ``mispzmq_settings.json`` in the ``tmp`` directory
    two levels above this file. The main loop pops from
    ``<namespace>:command`` (control commands ``kill``, ``reload``,
    ``status``) and from ``<namespace>:data:<topic>`` (payloads republished on
    ZMQ under ``<topic>``).
    """

    message_count = 0  # data messages forwarded to ZMQ
    publish_count = 0  # subset of message_count sent on the "misp_json" topic
    monitor_thread = None
    auth = None
    socket = None
    pidfile = None
    r: redis.StrictRedis
    namespace: str

    def __init__(self):
        """Refuse to start twice, then load the settings file.

        :raises Exception: if the PID file names a process that check_pid()
            reports as running, or if the settings file is missing.
        """
        self._logger = logging.getLogger()

        self.tmp_location = Path(__file__).parent.parent / "tmp"
        self.pidfile = self.tmp_location / "mispzmq.pid"
        if self.pidfile.exists():
            with open(self.pidfile.as_posix()) as f:
                pid = f.read()
            if check_pid(pid):
                raise Exception("mispzmq already running on PID {}".format(pid))
            # Stale PID file left behind by a process that is gone.
            self.pidfile.unlink()

        if (self.tmp_location / "mispzmq_settings.json").exists():
            self._setup()
        else:
            raise Exception("The settings file is missing.")

    def _setup(self):
        """(Re)load the JSON settings and (re)connect to Redis."""
        settings_path = self.tmp_location / "mispzmq_settings.json"
        with open(settings_path.as_posix()) as settings_file:
            self.settings = json.load(settings_file)
        self.namespace = self.settings["redis_namespace"]
        self.r = redis.StrictRedis(
            host=self.settings["redis_host"],
            db=self.settings["redis_database"],
            password=self.settings["redis_password"],
            port=self.settings["redis_port"],
            decode_responses=True,
        )
        # Reference point for the uptime reported in status/heartbeat messages.
        self.timestamp_settings = time.time()
        self._logger.debug("Connected to Redis {}:{}/{}".format(
            self.settings["redis_host"], self.settings["redis_port"],
            self.settings["redis_database"]))

    def _teardown_zmq(self):
        """Release the monitor, authenticator and socket of a previous setup, if any."""
        if self.monitor_thread:
            self.socket.disable_monitor()
            self.monitor_thread = None
        if self.auth:
            self.auth.stop()
            self.auth = None
        if self.socket:
            self.socket.close()
            self.socket = None

    def _setup_zmq(self):
        """(Re)create the PUB socket, optional PLAIN authenticator and debug monitor.

        Safe to call again on ``reload``: the previous socket is closed before
        binding, so the host:port is not still held by the old socket.

        :raises Exception: when ``username`` is set but ``password`` is empty.
        """
        # .get(): "username"/"password" are optional keys in the settings file.
        username = self.settings.get("username")
        password = self.settings.get("password")
        if username and not password:
            raise Exception("When username is set, password cannot be empty.")

        self._teardown_zmq()

        context = zmq.Context()
        if username:
            self.auth = ThreadAuthenticator(context)
            self.auth.start()
            self.auth.configure_plain(domain="*", passwords={username: password})

        self.socket = context.socket(zmq.PUB)
        if username:
            self.socket.plain_server = True  # must come before bind
        self.socket.bind("tcp://{}:{}".format(self.settings["host"], self.settings["port"]))
        self._logger.debug("ZMQ listening on tcp://{}:{}".format(
            self.settings["host"], self.settings["port"]))

        if self._logger.isEnabledFor(logging.DEBUG):
            monitor = self.socket.get_monitor_socket()
            self.monitor_thread = threading.Thread(target=event_monitor, args=(monitor, self._logger))
            self.monitor_thread.start()

    def _handle_command(self, command):
        """Execute a control command popped from ``<namespace>:command``."""
        if command == "kill":
            self._logger.info("Kill command received, shutting down.")
            self.clean()
            sys.exit()
        elif command == "reload":
            self._logger.info("Reload command received, reloading settings from file.")
            self._setup()
            self._setup_zmq()
        elif command == "status":
            self._logger.info("Status command received, responding with latest stats.")
            self.r.delete("{}:status".format(self.namespace))
            self.r.lpush("{}:status".format(self.namespace), json.dumps({
                "timestamp": time.time(),
                "timestampSettings": self.timestamp_settings,
                "publishCount": self.publish_count,
                "messageCount": self.message_count,
            }))
        else:
            self._logger.warning("Received invalid command '{}'.".format(command))

    def _create_pid_file(self):
        """Write the current process id to the PID file."""
        with open(self.pidfile.as_posix(), "w") as f:
            f.write(str(os.getpid()))

    def _pub_message(self, topic, data):
        """Publish ``data`` on ZMQ as a single "<topic> <data>" string."""
        self.socket.send_string("{} {}".format(topic, data))

    def clean(self):
        """Release ZMQ resources and remove the PID file (if it still exists)."""
        self._teardown_zmq()
        if self.pidfile:
            try:
                self.pidfile.unlink()
            except FileNotFoundError:
                pass

    def main(self):
        """Forward Redis messages to ZMQ until a ``kill`` command arrives."""
        self._create_pid_file()
        self._setup_zmq()
        time.sleep(1)

        status_array = [
            "And when you're dead I will be still alive.",
            "And believe me I am still alive.",
            "I'm doing science and I'm still alive.",
            "I feel FANTASTIC and I'm still alive.",
            "While you're dying I'll be still alive.",
        ]
        topics = [
            "misp_json", "misp_json_event", "misp_json_attribute", "misp_json_sighting",
            "misp_json_organisation", "misp_json_user", "misp_json_conversation",
            "misp_json_object", "misp_json_object_reference", "misp_json_audit",
            "misp_json_tag", "misp_json_warninglist",
        ]

        while True:
            # Rebuilt on every iteration so a "reload" that changes
            # redis_namespace takes effect immediately.
            prefix = "{}:".format(self.namespace)
            lists = ["{}:command".format(self.namespace)]
            lists.extend("{}:data:{}".format(self.namespace, topic) for topic in topics)

            data = self.r.blpop(lists, timeout=10)

            if data is None:
                # Redis timeout expired: publish a heartbeat/status message.
                current_time = int(time.time())
                uptime = current_time - int(self.timestamp_settings)
                status_message = {
                    "status": status_array[int(uptime / 10 % 5)],
                    "uptime": uptime,
                }
                self._pub_message("misp_json_self", json.dumps(status_message))
                self._logger.debug("No message received for 10 seconds, sending ZMQ status message.")
                continue

            key, value = data
            # Strip only the leading namespace. str.replace() would also remove
            # later occurrences, e.g. with a namespace named "data".
            if key.startswith(prefix):
                key = key[len(prefix):]

            if key == "command":
                self._handle_command(value)
            elif key.startswith("data:"):
                topic = key.split(":")[1]
                self._logger.debug("Received data for topic '{}', sending to ZMQ.".format(topic))
                self._pub_message(topic, value)
                self.message_count += 1
                if topic == "misp_json":
                    self.publish_count += 1
            else:
                self._logger.warning("Received invalid message '{}'.".format(key))
class RpcClient:
    """Client half of a ZMQ RPC pair: REQ socket for calls, SUB socket for pushed data.

    Any attribute not defined on the instance resolves (via ``__getattr__``)
    to a proxy that sends ``[name, args, kwargs]`` to the server and returns
    its reply. Data published by the server is handed to :meth:`callback`.
    """

    def __init__(self):
        """Create the ZMQ context and the (not yet connected) REQ/SUB sockets."""
        # zmq port related
        self.__context: zmq.Context = zmq.Context()

        # Request socket (Request–reply pattern)
        self.__socket_req: zmq.Socket = self.__context.socket(zmq.REQ)
        # A strict REQ socket refuses to send again until the previous reply
        # has been received, so a single timeout in dorpc() would make every
        # later call fail. REQ_RELAXED allows a new request after a timeout and
        # REQ_CORRELATE makes the socket drop the late reply to the old one.
        self.__socket_req.setsockopt(zmq.REQ_RELAXED, 1)
        self.__socket_req.setsockopt(zmq.REQ_CORRELATE, 1)

        # Subscribe socket (Publish–subscribe pattern)
        self.__socket_sub: zmq.Socket = self.__context.socket(zmq.SUB)

        # Worker thread related, used to process data pushed from server
        self.__active: bool = False  # RpcClient status
        self.__thread: threading.Thread = None  # RpcClient thread
        self.__lock: threading.Lock = threading.Lock()  # serialises REQ send/recv

        # Authenticator used to ensure data security
        self.__authenticator: ThreadAuthenticator = None

        self._last_received_ping: datetime = datetime.utcnow()

    # NOTE(review): lru_cache on a method keys on ``self`` and keeps the
    # instance alive for the cache's lifetime; it is kept so that repeated
    # lookups of the same remote name return the same proxy function.
    @lru_cache(100)
    def __getattr__(self, name: str):
        """Return a proxy that calls the remote function ``name``.

        The proxy accepts an extra ``timeout`` keyword (milliseconds, default
        30000) and raises ``RemoteException`` on timeout or when the server
        reports a failure.
        """
        def dorpc(*args, **kwargs):
            timeout = kwargs.pop("timeout", 30000)
            req = [name, args, kwargs]

            with self.__lock:
                self.__socket_req.send_pyobj(req)

                # Timeout reached without any data
                if not self.__socket_req.poll(timeout):
                    msg = f"Timeout of {timeout}ms reached for {req}"
                    raise RemoteException(msg)

                # NOTE(security): recv_pyobj() unpickles the reply; only
                # connect to trusted servers.
                rep = self.__socket_req.recv_pyobj()

            # rep is [success_flag, result_or_traceback]
            if rep[0]:
                return rep[1]
            raise RemoteException(rep[1])

        return dorpc

    def start(self,
              req_address: str,
              sub_address: str,
              client_secretkey_path: str = "",
              server_publickey_path: str = "",
              username: str = "",
              password: str = "") -> None:
        """Configure optional CURVE or PLAIN credentials, connect, and start the worker thread.

        Does nothing if the client is already active.
        """
        if self.__active:
            return

        # Start authenticator
        if client_secretkey_path and server_publickey_path:
            self.__authenticator = ThreadAuthenticator(self.__context)
            self.__authenticator.start()
            self.__authenticator.configure_curve(
                domain="*", location=zmq.auth.CURVE_ALLOW_ANY)

            publickey, secretkey = zmq.auth.load_certificate(client_secretkey_path)
            serverkey, _ = zmq.auth.load_certificate(server_publickey_path)

            self.__socket_sub.curve_secretkey = secretkey
            self.__socket_sub.curve_publickey = publickey
            self.__socket_sub.curve_serverkey = serverkey

            self.__socket_req.curve_secretkey = secretkey
            self.__socket_req.curve_publickey = publickey
            self.__socket_req.curve_serverkey = serverkey
        elif username and password:
            self.__authenticator = ThreadAuthenticator(self.__context)
            self.__authenticator.start()
            self.__authenticator.configure_plain(
                domain="*", passwords={username: password})

            self.__socket_sub.plain_username = username.encode()
            self.__socket_sub.plain_password = password.encode()

            self.__socket_req.plain_username = username.encode()
            self.__socket_req.plain_password = password.encode()

        # Connect zmq port
        self.__socket_req.connect(req_address)
        self.__socket_sub.connect(sub_address)

        # Start RpcClient status
        self.__active = True

        # Start RpcClient thread
        self.__thread = threading.Thread(target=self.run)
        self.__thread.start()

        self._last_received_ping = datetime.utcnow()

    def stop(self) -> None:
        """Ask the worker thread to exit (it closes the sockets itself)."""
        if not self.__active:
            return
        self.__active = False

    def join(self) -> None:
        """Wait for the worker thread to exit."""
        if self.__thread and self.__thread.is_alive():
            self.__thread.join()
        self.__thread = None

    def run(self) -> None:
        """Worker loop: dispatch pushed data until stop() is called, then release resources."""
        pull_tolerance = int(KEEP_ALIVE_TOLERANCE.total_seconds() * 1000)

        while self.__active:
            if not self.__socket_sub.poll(pull_tolerance):
                self.on_disconnected()
                continue

            # Receive data from subscribe socket.
            # NOTE(security): recv_pyobj() unpickles whatever the publisher
            # sends; only subscribe to trusted servers.
            topic, data = self.__socket_sub.recv_pyobj(flags=NOBLOCK)

            if topic == KEEP_ALIVE_TOPIC:
                self._last_received_ping = data
            else:
                # Process data by callable function
                self.callback(topic, data)

        # Close socket
        self.__socket_req.close()
        self.__socket_sub.close()

        # Stop the authenticator thread started in start(), if any.
        if self.__authenticator:
            self.__authenticator.stop()
            self.__authenticator = None

    def callback(self, topic: str, data: Any) -> None:
        """Handle data pushed by the server; subclasses must override."""
        raise NotImplementedError

    def subscribe_topic(self, topic: str) -> None:
        """Subscribe to messages published under ``topic``."""
        self.__socket_sub.setsockopt_string(zmq.SUBSCRIBE, topic)

    def on_disconnected(self):
        """Callback when heartbeat is lost."""
        print(
            "RpcServer has no response over {tolerance} seconds, please check you connection."
            .format(tolerance=KEEP_ALIVE_TOLERANCE.total_seconds()))
class RpcServer:
    """Server half of a ZMQ RPC pair: REP socket for calls, PUB socket for pushed data.

    Registered functions are invoked by name for each ``[name, args, kwargs]``
    request; the reply is ``[True, result]`` or ``[False, traceback_text]``.
    """

    def __init__(self):
        """Create the ZMQ context and the (not yet bound) REP/PUB sockets."""
        # Registered functions: key is function name, value is function object
        self.__functions: Dict[str, Any] = {}

        # Zmq port related
        self.__context: zmq.Context = zmq.Context()

        # Reply socket (Request–reply pattern)
        self.__socket_rep: zmq.Socket = self.__context.socket(zmq.REP)

        # Publish socket (Publish–subscribe pattern)
        self.__socket_pub: zmq.Socket = self.__context.socket(zmq.PUB)

        # Worker thread related
        self.__active: bool = False  # RpcServer status
        self.__thread: threading.Thread = None  # RpcServer thread
        self.__lock: threading.Lock = threading.Lock()  # serialises PUB sends

        # Authenticator used to ensure data security
        self.__authenticator: ThreadAuthenticator = None

    def is_active(self) -> bool:
        """Return True while the server worker is meant to be running."""
        return self.__active

    def start(self,
              rep_address: str,
              pub_address: str,
              server_secretkey_path: str = "",
              username: str = "",
              password: str = "") -> None:
        """Configure optional CURVE or PLAIN auth, bind both sockets, and start the worker thread.

        Does nothing if the server is already active.
        """
        if self.__active:
            return

        # Start authenticator
        if server_secretkey_path:
            self.__authenticator = ThreadAuthenticator(self.__context)
            self.__authenticator.start()
            self.__authenticator.configure_curve(
                domain="*", location=zmq.auth.CURVE_ALLOW_ANY)

            publickey, secretkey = zmq.auth.load_certificate(server_secretkey_path)

            self.__socket_pub.curve_secretkey = secretkey
            self.__socket_pub.curve_publickey = publickey
            self.__socket_pub.curve_server = True

            self.__socket_rep.curve_secretkey = secretkey
            self.__socket_rep.curve_publickey = publickey
            self.__socket_rep.curve_server = True
        elif username and password:
            self.__authenticator = ThreadAuthenticator(self.__context)
            self.__authenticator.start()
            self.__authenticator.configure_plain(
                domain="*", passwords={username: password})

            self.__socket_pub.plain_server = True
            self.__socket_rep.plain_server = True

        # Bind socket address
        self.__socket_rep.bind(rep_address)
        self.__socket_pub.bind(pub_address)

        # Start RpcServer status
        self.__active = True

        # Start RpcServer thread
        self.__thread = threading.Thread(target=self.run)
        self.__thread.start()

    def stop(self) -> None:
        """Ask the worker thread to exit (it unbinds the sockets itself)."""
        if not self.__active:
            return
        self.__active = False

    def join(self) -> None:
        """Wait for the worker thread to exit."""
        if self.__thread and self.__thread.is_alive():
            self.__thread.join()
        self.__thread = None

    def run(self) -> None:
        """Serve requests and publish keep-alives until stop() is called."""
        last_heartbeat = datetime.utcnow()
        while self.__active:
            now = datetime.utcnow()
            if now - last_heartbeat >= KEEP_ALIVE_INTERVAL:
                self.publish(KEEP_ALIVE_TOPIC, now)
                # Reset the reference point; without this a keep-alive is
                # published on every loop iteration once the first interval passed.
                last_heartbeat = now

            # Wait up to 1 second (1000 ms) for a request.
            if not self.__socket_rep.poll(1000):
                continue

            # NOTE(security): recv_pyobj() unpickles the request; anyone able
            # to reach the REP port without auth can execute code.
            req = self.__socket_rep.recv_pyobj()

            # A REP socket must reply to every request, so malformed requests
            # and function failures are both reported back instead of raising.
            try:
                name, args, kwargs = req
                func = self.__functions[name]
                result = func(*args, **kwargs)
                rep = [True, result]
            except Exception:  # noqa
                rep = [False, traceback.format_exc()]

            self.__socket_rep.send_pyobj(rep)

        # Unbind socket address
        self.__socket_pub.unbind(self.__socket_pub.LAST_ENDPOINT)
        self.__socket_rep.unbind(self.__socket_rep.LAST_ENDPOINT)

        # Stop the authenticator so a later start() can install a new one.
        if self.__authenticator:
            self.__authenticator.stop()
            self.__authenticator = None

    def publish(self, topic: str, data: Any) -> None:
        """Publish ``[topic, data]`` (pickled) on the PUB socket."""
        with self.__lock:
            self.__socket_pub.send_pyobj([topic, data])

    def register(self, func: Callable) -> None:
        """Expose ``func`` to clients under its ``__name__``."""
        self.__functions[func.__name__] = func
class FeatureComputer(object):
    """Serve layer outputs of a Keras-style model over a PLAIN-authenticated ZMQ PAIR socket.

    Incoming messages are dicts with a ``type`` key (``change_layer``,
    ``predict``, ``layer_info``, ``summary``); replies are sent with
    ``send_pyobj``.
    """

    def __init__(self, bind_str="tcp://127.0.0.1:5560", parent_model=None,
                 layer=None, logins=None, viewable_layers=None):
        self.context = zmq.Context.instance()
        self.auth = ThreadAuthenticator(self.context)
        self.auth.start()
        self.auth.configure_plain(domain='*', passwords=logins)
        self.socket = self.context.socket(zmq.PAIR)
        self.socket.plain_server = True
        self.socket.bind(bind_str)
        self.parent_model = parent_model
        self.curr_model = parent_model
        self.viewable_layers = viewable_layers

        self.config = tf.ConfigProto()
        self.config.gpu_options.per_process_gpu_memory_fraction = 0.3
        self.config.gpu_options.allow_growth = True

        # Default to layer 5 only when no layer was given, so an explicit
        # index of 0 is honoured.
        self.layer = 5 if layer is None else layer

    def _layers(self):
        """Return the layer list that ``self.layer`` indexes into."""
        return self.viewable_layers if self.viewable_layers else self.parent_model.layers

    def change_layer(self, *args, **kwargs):
        """Rebuild ``curr_model`` to output the layer given by ``kwargs['layer']``."""
        print("Changing layer")
        print(args, kwargs)
        self.layer = kwargs.get('layer', self.layer)
        self.curr_model = Model(
            input=[self.parent_model.layers[0].input],
            output=[self._layers()[self.layer].output])
        print(self.layer)
        self.socket.send_pyobj({'type': 'layer_changed', 'success': True}, zmq.NOBLOCK)

    def get_summary(self, *args, **kwargs):
        """Send a placeholder summary reply (result is always None)."""
        self.socket.send_pyobj({'type': 'summary', 'result': None}, zmq.NOBLOCK)

    def do_predict(self, *args, **kwargs):
        """Resize ``kwargs['input']`` to 224x224, run ``curr_model`` and send the result."""
        # TODO: Make this configurable
        image = kwargs.pop('input', np.zeros((1, 224, 224, 3)))
        resized = np.float64(cv2.resize(image, (224, 224)))
        preprocessed = preprocess_input(np.expand_dims(resized, axis=0))
        result = self.curr_model.predict(preprocessed, verbose=0)
        self.socket.send_pyobj({'type': 'prediction', 'result': result}, zmq.NOBLOCK)

    def do_layerinfo(self, *args, **kwargs):
        """Send the output shape and name of the currently selected layer."""
        self.socket.send_pyobj(
            {
                'type': 'layer_info',
                'shape': self.curr_model.compute_output_shape((1, 224, 224, 3)),
                # Same layer list as change_layer(), so the name matches the
                # layer actually being served.
                'name': self._layers()[self.layer].name,
            }, zmq.NOBLOCK)

    def do_summary(self, *args, **kwargs):
        """Send the names of the selectable layers."""
        self.socket.send_pyobj(
            {
                'type': 'summary',
                'result': [layer.name for layer in self._layers()],
            }, zmq.NOBLOCK)

    def handle_message(self, message):
        """Dispatch ``message`` to its handler by ``message['type']``; unknown types are ignored."""
        handlers = {
            'change_layer': self.change_layer,
            'predict': self.do_predict,
            'layer_info': self.do_layerinfo,
            'summary': self.do_summary,
        }
        handler = handlers.get(message['type'])
        if handler is not None:
            handler(**message)

    def run(self):
        """Receive and handle messages until ``self.running`` is cleared."""
        self.running = True
        while self.running:
            message = self.socket.recv_pyobj()
            self.handle_message(message)
class ZmqReceiver(object):
    """Receive messages on an optional REP socket and any number of SUB sockets.

    Messages are passed to :meth:`handle_incoming_message`; for the REP
    socket its return value is sent back as the reply.
    """

    def __init__(self, zmq_rep_bind_address=None, zmq_sub_connect_addresses=None,
                 recreate_sockets_on_timeout_of_sec=600, username=None, password=None):
        self.context = zmq.Context()
        self.auth = None
        self.last_received_message = None
        self.is_running = False
        self.thread = None
        self.zmq_rep_bind_address = zmq_rep_bind_address
        self.zmq_sub_connect_addresses = zmq_sub_connect_addresses
        self.poller = zmq.Poller()
        self.sub_sockets = []
        self.rep_socket = None
        if username is not None and password is not None:
            # Start an authenticator for this context.
            # Does not work on PUB/SUB as far as I know (probably because the
            # more secure solutions require two-way communication as well).
            self.auth = ThreadAuthenticator(self.context)
            self.auth.start()
            # Instruct authenticator to handle PLAIN requests
            self.auth.configure_plain(domain='*', passwords={username: password})

        if self.zmq_sub_connect_addresses:
            for address in self.zmq_sub_connect_addresses:
                self.sub_sockets.append(SubSocket(self.context, self.poller, address,
                                                  recreate_sockets_on_timeout_of_sec))
        if zmq_rep_bind_address:
            self.rep_socket = RepSocket(self.context, self.poller, zmq_rep_bind_address, self.auth)

    # run() polls with a 1000 ms timeout, so it may take about one poll cycle
    # to notice is_running was cleared (assuming the socket wrappers' recv_string
    # does not block further -- TODO confirm).
    def stop(self):
        """Request the run loop to exit and stop the authenticator."""
        self.is_running = False
        logger.info("Closing pub and sub sockets...")
        if self.auth is not None:
            self.auth.stop()

    def run(self):
        """Poll all sockets until stop() is called, then destroy them."""
        self.is_running = True
        while self.is_running:
            socks = dict(self.poller.poll(1000))
            logger.debug("Poll cycle over. checking sockets")
            if self.rep_socket:
                incoming_message = self.rep_socket.recv_string(socks)
                if incoming_message is not None:
                    self.last_received_message = incoming_message
                    logger.debug("Got info from REP socket")
                    try:
                        response_message = self.handle_incoming_message(incoming_message)
                    except Exception as e:
                        logger.error(e)
                        # A REP socket must answer each request before it can
                        # receive the next, so reply with an error instead of
                        # leaving the request unanswered.
                        response_message = self.create_response_message(500, str(e), None)
                    try:
                        self.rep_socket.send(response_message)
                    except Exception as e:
                        logger.error(e)
            for sub_socket in self.sub_sockets:
                incoming_message = sub_socket.recv_string(socks)
                if incoming_message is not None:
                    if incoming_message != "zmq_sub_heartbeat":
                        self.last_received_message = incoming_message
                    logger.debug("Got info from SUB socket")
                    try:
                        self.handle_incoming_message(incoming_message)
                    except Exception as e:
                        logger.error(e)
        if self.rep_socket:
            self.rep_socket.destroy()
        for sub_socket in self.sub_sockets:
            sub_socket.destroy()

    def create_response_message(self, status_code, status_message, response_message):
        """Return a JSON reply; ``response_message`` is included only when not None."""
        payload = {"status_code": status_code, "status_message": status_message}
        if response_message is not None:
            payload["response_message"] = response_message
        return json.dumps(payload)

    def handle_incoming_message(self, message):
        """Default handler: acknowledge with 200/OK; heartbeats get None. Override in subclasses."""
        if message != "zmq_sub_heartbeat":
            return self.create_response_message(200, "OK", None)
        return None
class Driver(drivers.BaseDriver):
    def __init__(
        self,
        args,
        encrypted_traffic_data=None,
        interface=None,
    ):
        """Initialize the Driver.

        :param args: Arguments parsed by argparse.
        :type args: Object
        :param encrypted_traffic_data: Encrypted traffic settings. The keys
            read here are ``enabled``, ``secret_keys_dir`` and
            ``public_keys_dir``.
        :type encrypted_traffic_data: Dictionary
        :param interface: The interface instance (client/server)
        :type interface: Object
        """
        self.thread_processor = multiprocessing.Process
        self.event = multiprocessing.Event()
        self.semaphore = multiprocessing.Semaphore
        self.flushqueue = _FlushQueue
        self.args = args
        if getattr(self.args, "zmq_generate_keys", False) is True:
            self._generate_certificates()
            print("New certificates generated")
            raise SystemExit(0)

        self.encrypted_traffic_data = encrypted_traffic_data

        mode = getattr(self.args, "mode", None)
        if mode == "client":
            self.bind_address = self.args.zmq_server_address
        elif mode == "server":
            self.bind_address = self.args.zmq_bind_address
        else:
            self.bind_address = "*"
        self.proto = "tcp"
        self.connection_string = "{proto}://{addr}".format(
            proto=self.proto, addr=self.bind_address
        )

        if self.encrypted_traffic_data:
            self.encrypted_traffic = self.encrypted_traffic_data.get("enabled")
            self.secret_keys_dir = self.encrypted_traffic_data.get("secret_keys_dir")
            self.public_keys_dir = self.encrypted_traffic_data.get("public_keys_dir")
        else:
            self.encrypted_traffic = False
            self.secret_keys_dir = None
            self.public_keys_dir = None

        self._context = zmq.Context()
        # NOTE(review): instance() is a classmethod returning the global
        # singleton context, not self._context -- confirm this is intended.
        self.ctx = self._context.instance()
        self.poller = zmq.Poller()
        self.interface = interface
        super(Driver, self).__init__(
            args=args,
            encrypted_traffic_data=self.encrypted_traffic_data,
            interface=interface,
        )
        self.bind_job = None
        self.bind_backend = None
        self.hwm = getattr(self.args, "zmq_highwater_mark", 1024)

    def __copy__(self):
        """Return a new copy of the driver."""
        return Driver(
            args=self.args,
            encrypted_traffic_data=self.encrypted_traffic_data,
            interface=self.interface,
        )

    def _backend_bind(self):
        """Bind an address to a backend socket and return the socket.

        :returns: Object
        """
        bind = self._socket_bind(
            socket_type=zmq.ROUTER,
            connection=self.connection_string,
            port=self.args.backend_port,
        )
        bind.set_hwm(self.hwm)
        self.log.debug(
            "Identity [ %s ] backend bind hwm state [ %s ]",
            self.identity,
            bind.get_hwm(),
        )
        return bind

    def _backend_connect(self):
        """Connect to a backend socket and return the socket.

        :returns: Object
        """
        self.log.debug("Establishing backend connection.")
        bind = self._socket_connect(
            socket_type=zmq.DEALER,
            connection=self.connection_string,
            port=self.args.backend_port,
        )
        bind.set_hwm(self.hwm)
        self.log.debug(
            "Identity [ %s ] backend connect hwm state [ %s ]",
            self.identity,
            bind.get_hwm(),
        )
        return bind

    def _bind_check(self, bind, interval=1, constant=1000):
        """Return True if a bind type contains work ready.

        :param bind: A given Socket bind to identify.
        :type bind: Object
        :param interval: Exponential Interval used to determine the polling
                         duration for a given socket.
        :type interval: Integer
        :param constant: Constant time used to poll for new jobs.
        :type constant: Integer
        :returns: Boolean
        """
        socks = dict(self.poller.poll(interval * constant))
        return socks.get(bind) == zmq.POLLIN

    def _close(self, socket):
        """Close ``socket`` (no-op for None), retrying for up to 60 seconds; errors are logged."""
        if socket is None:
            return

        try:
            socket.close(linger=2)
            close_time = time.time()
            while not socket.closed:
                if time.time() - close_time > 60:
                    raise TimeoutError(
                        "Job [ {} ] failed to close transfer socket".format(self.job_id)
                    )
                socket.close(linger=2)
                time.sleep(1)
        except Exception as e:
            self.log.error(
                "Ran into an exception while closing the socket %s",
                str(e),
            )
        else:
            self.log.debug("Backend socket closed")

    def _generate_certificates(self, base_dir="/etc/directord"):
        """Generate client and server CURVE certificate files.

        Existing keys are renamed with a ``.bak`` suffix first.

        :param base_dir: Directord configuration path.
        :type base_dir: String
        """
        keys_dir = os.path.join(base_dir, "certificates")
        public_keys_dir = os.path.join(base_dir, "public_keys")
        secret_keys_dir = os.path.join(base_dir, "private_keys")

        for item in [keys_dir, public_keys_dir, secret_keys_dir]:
            os.makedirs(item, exist_ok=True)

        # Run certificate backup
        self._move_certificates(directory=public_keys_dir, backup=True)
        self._move_certificates(
            directory=secret_keys_dir, backup=True, suffix=".key_secret"
        )

        # create new keys in certificates dir
        for item in ["server", "client"]:
            self._key_generate(keys_dir=keys_dir, key_type=item)

        # Move generated certificates in place
        self._move_certificates(
            directory=keys_dir,
            target_directory=public_keys_dir,
            suffix=".key",
        )
        self._move_certificates(
            directory=keys_dir,
            target_directory=secret_keys_dir,
            suffix=".key_secret",
        )

    def _job_bind(self):
        """Bind an address to a job socket and return the socket.

        :returns: Object
        """
        return self._socket_bind(
            socket_type=zmq.ROUTER,
            connection=self.connection_string,
            port=self.args.job_port,
        )

    def _job_connect(self):
        """Connect to a job socket and return the socket.

        :returns: Object
        """
        self.log.debug("Establishing Job connection.")
        return self._socket_connect(
            socket_type=zmq.DEALER,
            connection=self.connection_string,
            port=self.args.job_port,
        )

    def _key_generate(self, keys_dir, key_type):
        """Generate certificate.

        :param keys_dir: Full Directory path where a given key will be stored.
        :type keys_dir: String
        :param key_type: Key type to be generated.
        :type key_type: String
        """
        zmq_auth.create_certificates(keys_dir, key_type)

    @staticmethod
    def _move_certificates(directory, target_directory=None, backup=False, suffix=".key"):
        """Move certificates when required.

        :param directory: Set the origin path.
        :type directory: String
        :param target_directory: Set the target path (defaults to ``directory``).
        :type target_directory: String
        :param backup: Append ``.bak`` to the moved file names.
        :type backup: Boolean
        :param suffix: Set the search suffix
        :type suffix: String
        """
        for item in os.listdir(directory):
            if backup:
                target_file = "{}.bak".format(os.path.basename(item))
            else:
                target_file = os.path.basename(item)

            if item.endswith(suffix):
                os.rename(
                    os.path.join(directory, item),
                    os.path.join(target_directory or directory, target_file),
                )

    def _socket_bind(self, socket_type, connection, port, poller_type=None):
        """Return a socket object which has been bound to a given address.

        When the socket_type is not PUB, the bound socket will also be
        registered with self.poller as defined within the Interface class.

        :param socket_type: Set the Socket type, typically defined using a
                            ZeroMQ constant.
        :type socket_type: Integer
        :param connection: Set the Address information used for the bound
                           socket.
        :type connection: String
        :param port: Define the port which the socket will be bound to.
        :type port: Integer
        :param poller_type: Set the Socket type, typically defined using a
                            ZeroMQ constant.
        :type poller_type: Integer
        :returns: Object
        """
        if poller_type is None:
            poller_type = zmq.POLLIN

        bind = self._socket_context(socket_type=socket_type)
        auth_enabled = self.args.zmq_shared_key or self.args.zmq_curve_encryption

        if auth_enabled:
            self.auth = ThreadAuthenticator(self.ctx, log=self.log)
            self.auth.start()
            self.auth.allow()

            if self.args.zmq_shared_key:
                # Enables basic auth
                self.auth.configure_plain(
                    domain="*", passwords={"admin": self.args.zmq_shared_key}
                )
                bind.plain_server = True  # Enable shared key authentication
                self.log.info("Shared key authentication enabled.")
            elif self.args.zmq_curve_encryption:
                server_secret_file = os.path.join(self.secret_keys_dir, "server.key_secret")
                for item in [
                    self.public_keys_dir,
                    self.secret_keys_dir,
                    server_secret_file,
                ]:
                    if not os.path.exists(item):
                        raise SystemExit(
                            "The required path [ {} ] does not exist. Have"
                            " you generated your keys?".format(item)
                        )
                self.auth.configure_curve(domain="*", location=self.public_keys_dir)
                try:
                    server_public, server_secret = zmq_auth.load_certificate(server_secret_file)
                except OSError as e:
                    self.log.error(
                        "Failed to load certificates: %s, Configuration: %s",
                        str(e),
                        vars(self.args),
                    )
                    raise SystemExit("Failed to load certificates")
                else:
                    bind.curve_secretkey = server_secret
                    bind.curve_publickey = server_public
                    bind.curve_server = True  # Enable curve authentication

        bind.bind(
            "{connection}:{port}".format(
                connection=connection,
                port=port,
            )
        )

        if socket_type not in [zmq.PUB]:
            self.poller.register(bind, poller_type)

        return bind

    def _socket_connect(self, socket_type, connection, port, poller_type=None):
        """Return a socket object which has been connected to a given address.

        > A connection back to the server will wait 10 seconds for an ack
          before going into a retry loop. This is done to forcefully cycle
          the connection object to reset.

        :param socket_type: Set the Socket type, typically defined using a
                            ZeroMQ constant.
        :type socket_type: Integer
        :param connection: Set the Address information used for the
                           connection.
        :type connection: String
        :param port: Define the port which the socket will connect to.
        :type port: Integer
        :param poller_type: Set the Socket type, typically defined using a
                            ZeroMQ constant.
        :type poller_type: Integer
        :returns: Object
        """
        if poller_type is None:
            poller_type = zmq.POLLIN

        bind = self._socket_context(socket_type=socket_type)

        if self.args.zmq_shared_key:
            bind.plain_username = b"admin"  # User is hard coded.
            bind.plain_password = self.args.zmq_shared_key.encode()
            self.log.info("Shared key authentication enabled.")
        elif self.args.zmq_curve_encryption:
            client_secret_file = os.path.join(self.secret_keys_dir, "client.key_secret")
            server_public_file = os.path.join(self.public_keys_dir, "server.key")
            for item in [
                self.public_keys_dir,
                self.secret_keys_dir,
                client_secret_file,
                server_public_file,
            ]:
                if not os.path.exists(item):
                    raise SystemExit(
                        "The required path [ {} ] does not exist. Have"
                        " you generated your keys?".format(item)
                    )
            try:
                client_public, client_secret = zmq_auth.load_certificate(client_secret_file)
                server_public, _ = zmq_auth.load_certificate(server_public_file)
            except OSError as e:
                self.log.error(
                    "Error while loading certificates: %s. Configuration: %s",
                    str(e),
                    vars(self.args),
                )
                raise SystemExit("Failed to load keys.")
            else:
                bind.curve_secretkey = client_secret
                bind.curve_publickey = client_public
                bind.curve_serverkey = server_public

        if socket_type == zmq.SUB:
            bind.setsockopt_string(zmq.SUBSCRIBE, self.identity)
        else:
            bind.setsockopt_string(zmq.IDENTITY, self.identity)

        self.poller.register(bind, poller_type)
        bind.connect(
            "{connection}:{port}".format(
                connection=connection,
                port=port,
            )
        )

        self.log.info("Socket connected to [ %s ].", connection)
        return bind

    def _socket_context(self, socket_type):
        """Create socket context and return a bind object.

        :param socket_type: Set the Socket type, typically defined using a
                            ZeroMQ constant.
        :type socket_type: Integer
        :returns: Object
        """
        bind = self.ctx.socket(socket_type)
        # NOTE(review): ZMQ_LINGER is in milliseconds; heartbeat_interval may
        # be in seconds -- confirm the intended unit.
        bind.linger = getattr(self.args, "heartbeat_interval", 60)
        # set_hwm() applies the value to both SNDHWM and RCVHWM.
        bind.set_hwm(int(self.hwm * 4))

        if socket_type == zmq.ROUTER:
            bind.setsockopt(zmq.ROUTER_MANDATORY, 1)

        return bind

    @staticmethod
    def _socket_recv(socket, nonblocking=False):
        """Receive a message over a ZMQ socket.

        The message specification for server is as follows.

            [
                b"Identity"
                b"ID",
                b"ASCII Control Characters",
                b"command",
                b"data",
                b"info",
                b"stderr",
                b"stdout",
            ]

        The message specification for client is as follows.

            [
                b"ID",
                b"ASCII Control Characters",
                b"command",
                b"data",
                b"info",
                b"stderr",
                b"stdout",
            ]

        All message parts are byte encoded.

        All possible control characters are defined within the Interface class.
        For more on control characters review the following
        URL(https://donsnotes.com/tech/charsets/ascii.html#cntrl).

        :param socket: ZeroMQ socket object.
        :type socket: Object
        :param nonblocking: Enable non-blocking receive.
        :type nonblocking: Boolean
        """
        flags = zmq.NOBLOCK if nonblocking else 0
        return socket.recv_multipart(flags=flags)

    @tenacity.retry(
        retry=tenacity.retry_if_exception_type(Exception),
        wait=tenacity.wait_fixed(5),
        before_sleep=tenacity.before_sleep_log(
            logger.getLogger(name="directord"), logging.WARN
        ),
    )
    def _socket_send(
        self,
        socket,
        identity=None,
        msg_id=None,
        control=None,
        command=None,
        data=None,
        info=None,
        stderr=None,
        stdout=None,
        nonblocking=False,
    ):
        """Send a message over a ZMQ socket.

        The message specification for server is as follows.

            [
                b"Identity"
                b"ID",
                b"ASCII Control Characters",
                b"command",
                b"data",
                b"info",
                b"stderr",
                b"stdout",
            ]

        The message specification for client is as follows.

            [
                b"ID",
                b"ASCII Control Characters",
                b"command",
                b"data",
                b"info",
                b"stderr",
                b"stdout",
            ]

        Parts that are ``str`` are encoded before sending; empty parts are
        replaced by ``self.nullbyte``. A failed send is retried every
        5 seconds by the tenacity decorator.

        All possible control characters are defined within the Interface class.
        For more on control characters review the following
        URL(https://donsnotes.com/tech/charsets/ascii.html#cntrl).

        :param socket: ZeroMQ socket object.
        :type socket: Object
        :param identity: Target where message will be sent.
        :type identity: Bytes
        :param msg_id: ID information for a given message. If no ID is
                       provided a UUID will be generated.
        :type msg_id: Bytes
        :param control: ASCII control characters.
        :type control: Bytes
        :param command: Command definition for a given message.
        :type command: Bytes
        :param data: Encoded data that will be transmitted.
        :type data: Bytes
        :param info: Encoded information that will be transmitted.
        :type info: Bytes
        :param stderr: Encoded error information from a command.
        :type stderr: Bytes
        :param stdout: Encoded output information from a command.
        :type stdout: Bytes
        :param nonblocking: Enable non-blocking send.
        :type nonblocking: Boolean
        :returns: Object
        """

        def _encoder(item):
            try:
                return item.encode()
            except AttributeError:
                return item

        message_parts = [
            msg_id or utils.get_uuid(),
            control or self.nullbyte,
            command or self.nullbyte,
            data or self.nullbyte,
            info or self.nullbyte,
            stderr or self.nullbyte,
            stdout or self.nullbyte,
        ]
        if identity:
            message_parts.insert(0, identity)

        message_parts = [_encoder(i) for i in message_parts]
        flags = zmq.NOBLOCK if nonblocking else 0

        try:
            return socket.send_multipart(message_parts, flags=flags)
        except Exception:
            # Logger.warn is a deprecated alias of Logger.warning.
            self.log.warning("Failed to send message to [ %s ]", identity)
            raise

    def _recv(self, socket, nonblocking=False):
        """Receive message.

        :param socket: ZeroMQ socket object.
        :type socket: Object
        :param nonblocking: Enable non-blocking receive.
        :type nonblocking: Boolean
        :returns: Tuple of decoded message parts.
        """
        recv_obj = self._socket_recv(socket=socket, nonblocking=nonblocking)
        return tuple(i.decode() for i in recv_obj)

    def backend_recv(self, nonblocking=False):
        """Receive a transfer message.

        :param nonblocking: Enable non-blocking receive.
        :type nonblocking: Boolean
        :returns: Tuple
        """
        return self._recv(socket=self.bind_backend, nonblocking=nonblocking)

    def backend_init(self):
        """Initialize the backend socket.

        For server mode, this is a bound local socket.
        For client mode, it is a connection to the server socket.

        :returns: Object
        """
        if self.args.mode == "server":
            self.bind_backend = self._backend_bind()
        else:
            self.bind_backend = self._backend_connect()

    def backend_close(self):
        """Close the backend socket."""
        self._close(socket=self.bind_backend)

    def backend_check(self, interval=1, constant=1000):
        """Return True if the backend contains work ready.

        :param interval: Exponential Interval used to determine the polling
                         duration for a given socket.
        :type interval: Integer
        :param constant: Constant time used to poll for new jobs.
        :type constant: Integer
        :returns: Boolean
        """
        return self._bind_check(bind=self.bind_backend, interval=interval, constant=constant)

    def backend_send(self, *args, **kwargs):
        """Send a message on the backend socket.

        * All args and kwargs are passed through to the socket send.

        :returns: Object
        """
        kwargs["socket"] = self.bind_backend
        return self._socket_send(*args, **kwargs)

    @staticmethod
    def get_lock():
        """Return a multiprocessing lock."""
        return multiprocessing.Lock()

    def heartbeat_send(self, host_uptime=None, agent_uptime=None, version=None, driver=None):
        """Send a heartbeat on the job socket.

        :param host_uptime: Sender uptime
        :type host_uptime: String
        :param agent_uptime: Sender agent uptime
        :type agent_uptime: String
        :param version: Sender directord version
        :type version: String
        :param driver: Driver information
        :type driver: String
        """
        job_id = utils.get_uuid()
        self.log.info(
            "Job [ %s ] sending heartbeat from [ %s ] to server",
            job_id,
            self.identity,
        )

        return self.job_send(
            control=self.heartbeat_notice,
            msg_id=job_id,
            data=json.dumps(
                {
                    "job_id": job_id,
                    "version": version,
                    "host_uptime": host_uptime,
                    "agent_uptime": agent_uptime,
                    "machine_id": self.machine_id,
                    "driver": driver,
                }
            ),
        )

    def job_send(self, *args, **kwargs):
        """Send a job message.

        * All args and kwargs are passed through to the socket send.

        :returns: Object
        """
        kwargs["socket"] = self.bind_job
        return self._socket_send(*args, **kwargs)

    def job_recv(self, nonblocking=False):
        """Receive a job message.

        :param nonblocking: Enable non-blocking receive.
        :type nonblocking: Boolean
        :returns: Tuple
        """
        return self._recv(socket=self.bind_job, nonblocking=nonblocking)

    def job_init(self):
        """Initialize the job socket.

        For server mode, this is a bound local socket.
        For client mode, it is a connection to the server socket.

        :returns: Object
        """
        if self.args.mode == "server":
            self.bind_job = self._job_bind()
        else:
            self.bind_job = self._job_connect()

    def job_close(self):
        """Close the job socket."""
        self._close(socket=self.bind_job)

    def job_check(self, interval=1, constant=1000):
        """Return True if a job contains work ready.

        :param interval: Exponential Interval used to determine the polling
                         duration for a given socket.
        :type interval: Integer
        :param constant: Constant time used to poll for new jobs.
        :type constant: Integer
        :returns: Boolean
        """
        return self._bind_check(bind=self.bind_job, interval=interval, constant=constant)

    def shutdown(self):
        """Shutdown the driver.

        Sockets are closed before their contexts are released.
        """
        self.job_close()
        self.backend_close()
        # NOTE(review): pyzmq contexts expose term()/destroy() rather than
        # close(), so these guarded calls may be no-ops -- confirm intent.
        if hasattr(self.ctx, "close"):
            self.ctx.close()
        if hasattr(self._context, "close"):
            self._context.close()
class ZmqReceiver(ZmqBase):
    '''
    A ZmqReceiver class will listen on a REP or SUB socket for messages and
    will call the 'handle_incoming_message()' method to process it.
    Subclasses should override that.
    A response must be implemented for REP sockets, but is useless for SUB
    sockets.
    '''

    def __init__(self,
                 zmq_rep_bind_address: Optional[str] = None,
                 zmq_sub_connect_addresses: Optional[Tuple[SubSocketAddress, ...]] = None,
                 recreate_timeout: Optional[int] = 600,
                 username: Optional[str] = None,
                 password: Optional[str] = None):
        '''
        :param zmq_rep_bind_address: address the REP socket binds to; no REP
            socket is created when it is falsy.
        :param zmq_sub_connect_addresses: addresses to connect SUB sockets to.
        :param recreate_timeout: passed to each SubSocket as timeout_in_sec.
        :param username: PLAIN-auth user name (auth enabled only when both
            username and password are given).
        :param password: PLAIN-auth password.
        '''
        super().__init__()
        self.__context = zmq.Context()
        self.__poller = zmq.Poller()

        self.__sub_sockets = tuple(
            SubSocket(
                ctx=self.__context,
                poller=self.__poller,
                address=address,
                timeout_in_sec=recreate_timeout,
            )
            for address in (zmq_sub_connect_addresses or tuple())
        )

        self.__auth: Optional[ThreadAuthenticator] = None
        if username is not None and password is not None:
            # Start an authenticator for this context.
            # Does not work on PUB/SUB as far as I know (probably because the
            # more secure solutions require two way communication as well)
            self.__auth = ThreadAuthenticator(self.__context)
            self.__auth.start()
            # Instruct authenticator to handle PLAIN requests
            self.__auth.configure_plain(domain='*', passwords={
                username: password,
            })

        self.__rep_socket = RepSocket(
            ctx=self.__context,
            poller=self.__poller,
            address=zmq_rep_bind_address,
            auth=self.__auth,
        ) if zmq_rep_bind_address else None

        self.__last_received_message = None
        self.__is_running = False

    @property
    def is_running(self) -> bool:
        return self.__is_running

    def stop(self) -> None:
        '''
        Ask run() to exit and stop the authenticator.

        run() only notices the flag after its current poll returns (the poll
        timeout is 1000 ms) and any in-flight message has been handled.
        '''
        self._info('Closing pub and sub sockets...')
        self.__is_running = False
        if self.__auth is not None:
            self.__auth.stop()

    def _run_rep_socket(self, socks) -> None:
        '''Handle one pending REP request (if any) and send its response.'''
        if self.__rep_socket is None:
            return
        incoming_message = self.__rep_socket.recv_string(socks)
        if incoming_message is None:
            return
        if incoming_message != self.HEARTBEAT_MSG:
            self.__last_received_message = incoming_message
        self._debug('Got info from REP socket')
        try:
            response_message = self.handle_incoming_message(incoming_message)
            self.__rep_socket.send(response_message)
        except Exception as e:
            self._error(e)

    def _run_sub_sockets(self, socks) -> None:
        '''Handle one pending message (if any) on each SUB socket.'''
        for sub_socket in self.__sub_sockets:
            incoming_message = sub_socket.recv_string(socks)
            if incoming_message is None:
                continue
            if incoming_message != self.HEARTBEAT_MSG:
                self.__last_received_message = incoming_message
            self._debug('Got info from SUB socket')
            try:
                self.handle_incoming_message(incoming_message)
            except Exception as e:
                self._error(e)

    def run(self) -> None:
        '''
        Poll all sockets until stop() is called, then destroy the sockets.

        Sockets are destroyed in ``finally`` so they are released even when
        the loop is interrupted (e.g. KeyboardInterrupt).
        '''
        self.__is_running = True
        try:
            while self.__is_running:
                try:
                    socks = dict(self.__poller.poll(1000))
                except Exception as e:
                    # Log and keep polling; BaseException subclasses such as
                    # KeyboardInterrupt/SystemExit must not be swallowed here.
                    self._error(e)
                    continue
                self._debug('Poll cycle over. checking sockets')
                self._run_rep_socket(socks)
                self._run_sub_sockets(socks)
        finally:
            self.__is_running = False
            if self.__rep_socket:
                self.__rep_socket.destroy()
            for sub_socket in self.__sub_sockets:
                sub_socket.destroy()

    def create_response_message(self,
                                status_code: int,
                                status_message: str,
                                response_message: Optional[str] = None) -> str:
        '''Serialize a status payload; response_message is omitted when None.'''
        payload = {
            self.STATUS_CODE: status_code,
            self.STATUS_MSG: status_message,
        }
        if response_message is not None:
            payload[self.RESPONSE_MSG] = response_message

        return json.dumps(payload)

    def handle_incoming_message(self, message: str) -> Optional[str]:
        '''Default handler: None for heartbeats, an OK response otherwise.'''
        if message == self.HEARTBEAT_MSG:
            return None

        return self.create_response_message(
            status_code=self.STATUS_CODE_OK,
            status_message=self.STATUS_MSG_OK,
        )

    def get_last_received_message(self) -> Optional[str]:
        return self.__last_received_message

    def get_sub_socket(self, idx: int) -> SubSocket:
        return self.__sub_sockets[idx]
class ZmqReceiver(object):
    """Listen on an optional REP socket and any number of SUB sockets.

    Every incoming message is passed to handle_incoming_message(); for the
    REP socket its return value is sent back as the reply. Subclasses are
    expected to override handle_incoming_message().
    """

    def __init__(self, zmq_rep_bind_address=None, zmq_sub_connect_addresses=None,
                 recreate_sockets_on_timeout_of_sec=600, username=None, password=None):
        self.context = zmq.Context()
        self.auth = None
        self.last_received_message = None
        self.is_running = False
        self.thread = None
        self.zmq_rep_bind_address = zmq_rep_bind_address
        self.zmq_sub_connect_addresses = zmq_sub_connect_addresses
        self.poller = zmq.Poller()
        self.sub_sockets = []
        self.rep_socket = None
        if username is not None and password is not None:
            # Start an authenticator for this context.
            # Does not work on PUB/SUB as far as I know (probably because the
            # more secure solutions require two way communication as well)
            self.auth = ThreadAuthenticator(self.context)
            self.auth.start()
            # Instruct authenticator to handle PLAIN requests
            self.auth.configure_plain(domain='*', passwords={username: password})

        if self.zmq_sub_connect_addresses:
            for address in self.zmq_sub_connect_addresses:
                self.sub_sockets.append(
                    SubSocket(self.context, self.poller, address,
                              recreate_sockets_on_timeout_of_sec))
        if zmq_rep_bind_address:
            self.rep_socket = RepSocket(self.context, self.poller,
                                        zmq_rep_bind_address, self.auth)

    def stop(self):
        """Ask run() to exit and stop the authenticator.

        run() notices the flag only after its current poll returns (the poll
        timeout is 1000 ms) and any in-flight message has been handled.
        """
        self.is_running = False
        logger.info("Closing pub and sub sockets...")
        if self.auth is not None:
            self.auth.stop()

    def _handle_rep_socket(self, socks):
        """Handle one pending REP request (if any) and send its response."""
        if not self.rep_socket:
            return
        incoming_message = self.rep_socket.recv_string(socks)
        if incoming_message is None:
            return
        self.last_received_message = incoming_message
        try:
            logger.debug("Got info from REP socket")
            response_message = self.handle_incoming_message(incoming_message)
            self.rep_socket.send(response_message)
        except Exception as e:
            logger.error(e)

    def _handle_sub_sockets(self, socks):
        """Handle one pending message (if any) on each SUB socket."""
        for sub_socket in self.sub_sockets:
            incoming_message = sub_socket.recv_string(socks)
            if incoming_message is None:
                continue
            if incoming_message != "zmq_sub_heartbeat":
                self.last_received_message = incoming_message
            logger.debug("Got info from SUB socket")
            try:
                self.handle_incoming_message(incoming_message)
            except Exception as e:
                logger.error(e)

    def run(self):
        """Poll all sockets until stop() is called, then destroy them.

        Cleanup runs in ``finally`` so sockets are released even when polling
        raises; is_running is reset so callers see the loop has ended.
        """
        self.is_running = True
        try:
            while self.is_running:
                socks = dict(self.poller.poll(1000))
                logger.debug("Poll cycle over. checking sockets")
                self._handle_rep_socket(socks)
                self._handle_sub_sockets(socks)
        finally:
            self.is_running = False
            if self.rep_socket:
                self.rep_socket.destroy()
            for sub_socket in self.sub_sockets:
                sub_socket.destroy()

    def create_response_message(self, status_code, status_message, response_message=None):
        """Serialize a status payload as JSON.

        The "response_message" key is included only when response_message is
        not None.
        """
        payload = {
            "status_code": status_code,
            "status_message": status_message,
        }
        if response_message is not None:
            payload["response_message"] = response_message
        return json.dumps(payload)

    def handle_incoming_message(self, message):
        """Default handler: None for heartbeats, a 200/OK response otherwise."""
        if message == "zmq_sub_heartbeat":
            return None
        return self.create_response_message(200, "OK", None)
class RpcServer:
    """
    Request/reply RPC server with a companion PUB socket.

    Clients send ``(name, args, kwargs)`` over the REP socket; the function
    registered under ``name`` is called and ``[True, result]`` or
    ``[False, traceback_text]`` is sent back. ``publish`` broadcasts
    ``[topic, data]`` on the PUB socket, and a heartbeat is published on
    ``HEARTBEAT_TOPIC`` every ``HEARTBEAT_INTERVAL`` seconds while running.
    """

    def __init__(self) -> None:
        """
        Constructor
        """
        # Save functions dict: key is function name, value is function object
        self._functions: Dict[str, Callable] = {}

        # Zmq port related
        self._context: zmq.Context = zmq.Context()

        # Reply socket (Request–reply pattern)
        self._socket_rep: zmq.Socket = self._context.socket(zmq.REP)

        # Publish socket (Publish–subscribe pattern)
        self._socket_pub: zmq.Socket = self._context.socket(zmq.PUB)

        # Worker thread related
        self._active: bool = False                  # RpcServer status
        self._thread: threading.Thread = None       # RpcServer thread
        self._lock: threading.Lock = threading.Lock()

        # Heartbeat related: time() timestamp at which the next heartbeat is due
        self._heartbeat_at: float = None

        # Authenticator used to ensure data security
        self.__authenticator: ThreadAuthenticator = None

    def is_active(self) -> bool:
        """Return True while the server is started and not stopped."""
        return self._active

    def start(
        self,
        rep_address: str,
        pub_address: str,
        username: str = "",
        password: str = "",
        server_secretkey_path: str = ""
    ) -> None:
        """
        Start RpcServer: configure optional authentication, bind both sockets
        and launch the worker thread. Does nothing if already active.

        :param rep_address: endpoint the REP socket binds to.
        :param pub_address: endpoint the PUB socket binds to.
        :param username: PLAIN user name (used only with password and no key).
        :param password: PLAIN password.
        :param server_secretkey_path: CURVE certificate file; takes precedence
            over PLAIN credentials when given.
        """
        if self._active:
            return

        # Start authenticator. The sockets and context live in the
        # single-underscore attributes created by __init__; the previous
        # ``self.__context``/``self.__socket_*`` spellings were name-mangled
        # to non-existent ``_RpcServer__*`` attributes and raised
        # AttributeError whenever authentication was requested.
        if server_secretkey_path:
            self.__authenticator = ThreadAuthenticator(self._context)
            self.__authenticator.start()
            self.__authenticator.configure_curve(
                domain="*",
                location=zmq.auth.CURVE_ALLOW_ANY
            )

            publickey, secretkey = zmq.auth.load_certificate(server_secretkey_path)

            self._socket_pub.curve_secretkey = secretkey
            self._socket_pub.curve_publickey = publickey
            self._socket_pub.curve_server = True

            self._socket_rep.curve_secretkey = secretkey
            self._socket_rep.curve_publickey = publickey
            self._socket_rep.curve_server = True
        elif username and password:
            self.__authenticator = ThreadAuthenticator(self._context)
            self.__authenticator.start()
            self.__authenticator.configure_plain(
                domain="*",
                passwords={username: password}
            )

            self._socket_pub.plain_server = True
            self._socket_rep.plain_server = True

        # Bind socket address
        self._socket_rep.bind(rep_address)
        self._socket_pub.bind(pub_address)

        # Schedule the first heartbeat before the worker starts: run() calls
        # check_heartbeat(), which compares against _heartbeat_at.
        self._heartbeat_at = time() + HEARTBEAT_INTERVAL

        # Start RpcServer status
        self._active = True

        # Start RpcServer thread
        self._thread = threading.Thread(target=self.run)
        self._thread.start()

    def stop(self) -> None:
        """
        Stop RpcServer: clear the active flag so run() exits after its
        current poll. Use join() to wait for the thread.
        """
        if not self._active:
            return

        # Stop RpcServer status
        self._active = False

    def join(self) -> None:
        """Wait for the RpcServer thread to exit."""
        if self._thread and self._thread.is_alive():
            self._thread.join()
        self._thread = None

    def run(self) -> None:
        """
        Run RpcServer functions: serve requests until stop(), then unbind
        both sockets and stop the authenticator.
        """
        while self._active:
            # Poll response socket for 1 second
            n: int = self._socket_rep.poll(1000)
            self.check_heartbeat()

            if not n:
                continue

            # NOTE(security): recv_pyobj unpickles the payload, so any peer
            # that can reach the REP socket can execute code in this process.
            # Only expose it to trusted clients or enable CURVE/PLAIN auth.
            req = self._socket_rep.recv_pyobj()

            # Unpack, look up and execute inside the try so a malformed request
            # still gets an error reply: a REP socket must answer every request
            # before it can receive the next one.
            try:
                name, args, kwargs = req
                func: Callable = self._functions[name]
                r: Any = func(*args, **kwargs)
                rep: list = [True, r]
            except Exception:  # noqa
                rep = [False, traceback.format_exc()]

            # send callable response by Reply socket
            self._socket_rep.send_pyobj(rep)

        # Unbind socket address
        self._socket_pub.unbind(self._socket_pub.LAST_ENDPOINT)
        self._socket_rep.unbind(self._socket_rep.LAST_ENDPOINT)

        if self.__authenticator:
            self.__authenticator.stop()

    def publish(self, topic: str, data: Any) -> None:
        """
        Publish ``[topic, data]`` on the PUB socket (serialized by the lock).
        """
        with self._lock:
            self._socket_pub.send_pyobj([topic, data])

    def register(self, func: Callable) -> None:
        """
        Register function under its ``__name__``.
        """
        self._functions[func.__name__] = func

    def check_heartbeat(self) -> None:
        """
        Publish a heartbeat if it is due and schedule the next one.
        """
        now: float = time()

        if now >= self._heartbeat_at:
            # Publish heartbeat
            self.publish(HEARTBEAT_TOPIC, now)

            # Update timestamp of next publish
            self._heartbeat_at = now + HEARTBEAT_INTERVAL
class CombaZMQAdapter(threading.Thread, CombaBase):
    """ZeroMQ REP server thread that forwards text commands to a CombaController.

    A request is a space-separated string: the first token names a method of
    the controller and the remaining tokens are passed to it as string
    arguments. Replies are sent through send().
    """

    def __init__(self, port):
        self.port = str(port)
        threading.Thread.__init__(self)
        self.shutdown_event = Event()
        # Context.instance() is a classmethod returning the process-wide
        # context; calling it on a fresh Context() only leaked that object.
        self.context = zmq.Context.instance()
        self.authserver = ThreadAuthenticator(self.context)
        self.loadConfig()
        self.start()

    #------------------------------------------------------------------------------------------#
    def run(self):
        """
        Bind the REP socket and serve requests until shutdown_event is set.
        Runs on the worker thread after start() (reload() also calls it).
        """
        self.startAuthserver()
        self.data = ''
        self.socket = self.context.socket(zmq.REP)
        self.socket.plain_server = True
        self.socket.bind("tcp://*:" + self.port)
        self.shutdown_event.clear()
        self.controller = CombaController(self, self.lqs_socket, self.lqs_recorder_socket)
        self.controller.messenger.setMailAddresses(self.get('frommail'), self.get('adminmail'))
        self.can_send = False
        # Process tasks forever
        while not self.shutdown_event.is_set():
            self.data = self.socket.recv()
            self.can_send = True
            data = self.data.split(' ')
            command = str(data.pop(0))
            self._callController(command, data)
        return

    #------------------------------------------------------------------------------------------#
    def _callController(self, command, args):
        """
        Call self.controller.<command>(*args) and report failures to the client.

        The request text used to be spliced into an ``exec`` statement, which
        let any client able to send a request run arbitrary Python (for example
        ``status;__import__('os').system(...)``). The method is now looked up
        with getattr and the arguments are passed as plain strings, so only
        public controller methods are reachable.

        :param command: method name taken from the first request token
        :param args: list of the remaining request tokens (strings)
        """
        import re
        if not re.match(r'[A-Za-z_][A-Za-z0-9_]*\Z', command):
            # Anything that is not a plain method name (empty, dotted,
            # punctuation) is rejected with the syntax-error warning.
            self.controller.message('Warning: Syntax Error')
            return
        try:
            if command.startswith('_'):
                # Private/dunder attributes are not part of the command API.
                raise AttributeError(command)
            getattr(self.controller, command)(*args)
        except SyntaxError:
            self.controller.message('Warning: Syntax Error')
        except AttributeError:
            print("Warning: Method " + command + " does not exist")
            self.controller.message('Warning: Method ' + command + ' does not exist')
        except TypeError:
            print("Warning: Wrong number of params")
            self.controller.message('Warning: Wrong number of params')
        except Exception:
            print("Warning: Unknown Error")
            self.controller.message('Warning: Unknown Error')

    #------------------------------------------------------------------------------------------#
    def halt(self):
        """
        Stop the server: drop the controller, set shutdown_event and unbind
        the REP socket. Does nothing if shutdown_event is already set.
        """
        if self.shutdown_event.is_set():
            return
        try:
            del self.controller
        except AttributeError:
            # run() has not created a controller yet
            pass
        self.shutdown_event.set()
        try:
            # libzmq cannot unbind a wildcard endpoint ("tcp://*:port") by that
            # string; it needs the resolved endpoint reported by
            # ZMQ_LAST_ENDPOINT. Otherwise the port stays bound and the
            # re-bind in run() (via reload()) fails.
            self.socket.unbind(self.socket.getsockopt(zmq.LAST_ENDPOINT))
        except Exception:
            # Best effort: the socket may not exist yet or be unbound already.
            pass

    #------------------------------------------------------------------------------------------#
    def reload(self):
        """
        Stop, reload config and start again.
        """
        if self.shutdown_event.is_set():
            return
        self.loadConfig()
        self.halt()
        time.sleep(3)
        self.run()

    #------------------------------------------------------------------------------------------#
    def send(self, message):
        """
        Send a message to the client (at most one reply per received request)
        :param message: string
        """
        if self.can_send:
            self.socket.send(message)
            self.can_send = False

    #------------------------------------------------------------------------------------------#
    def startAuthserver(self):
        """
        (Re)start the zmq authentication server according to self.securitylevel.
        """
        # stop auth server if running
        if self.authserver.is_alive():
            self.authserver.stop()
        if self.securitylevel > 0:
            # Start the authentication server.
            self.authserver.start()
            # At security level 2 and above, additionally allow only the
            # whitelisted client addresses.
            if self.securitylevel > 1:
                try:
                    addresses = CombaWhitelist().getList()
                    for address in addresses:
                        self.authserver.allow(address)
                except Exception:
                    # Best effort, but make the missing restriction visible.
                    print("Warning: Could not load address whitelist")

            # Instruct authenticator to handle PLAIN requests
            self.authserver.configure_plain(domain='*', passwords=self.getAccounts())

    #------------------------------------------------------------------------------------------#
    def getAccounts(self):
        """
        Get the PLAIN-auth accounts.

        Merges CombaUser().getLogins() with an internal account stored in redis
        under 'internAccess' as 'user:password'; that account is generated and
        stored on first use.
        :return: dict - mapping of user name to password
        """
        accounts = CombaUser().getLogins()
        db = redis.Redis()

        internaccount = db.get('internAccess')
        if not internaccount:
            # Credentials must come from a CSPRNG; string.ascii_* are
            # locale-independent, unlike string.lowercase/uppercase.
            rng = random.SystemRandom()
            user = ''.join(rng.sample(string.ascii_lowercase, 10))
            password = ''.join(rng.sample(string.ascii_letters + string.digits, 22))
            db.set('internAccess', user + ':' + password)
            intern = [user, password]
        else:
            intern = internaccount.split(':', 1)

        accounts[intern[0]] = intern[1]

        return accounts