class Facilities(furnace_runtime.FurnaceRuntime):
    """
    Main Facilities class.

    Relays protobuf ``FacMessage`` traffic between the APP partition
    (frontend, "FE", reached over IPC sockets) and the backend proxy
    ("BE"/"PXY", reached over CURVE-secured TCP sockets), and serves the
    app's GET/SET key-value IO requests from local files under FAC_IO_DIR.
    """

    def __init__(self, debug, u2, kp, qt, it, ip, fac):
        """
        Constructor, ZMQ connections to the APP partition and backend are built here.

        :param debug: Debug setting; forwarded to the base class and stored.
        :param u2: The UUID for the corresponding app instance.
        :param kp: The keypair to use. This constructor reads the certificate
            paths ``kp['be_pub']`` (backend public key) and ``kp['app_key']``.
        :param qt: Custom quotas dict; ``qt['_furnace_max_disk']`` is the
            storage quota enforced by process_set.
        :param ip: Connect to this proxy IP.
        :param it: Connect to this proxy TCP port.
        :param fac: The path to the FAC partition's ZMQ socket.
        :raises Exception: if protobuf is running in pure-python mode.
        """
        super(Facilities, self).__init__(debug=debug)
        self.debug = debug
        self.u2 = u2
        self.fac_path = fac
        self.it = it
        self.ip = ip
        self.already_shutdown = False
        self.tprint(DEBUG, "starting Facilities constructor")
        self.tprint(DEBUG, f"running as UUID={u2}")

        api = api_implementation._default_implementation_type
        if api == "python":
            raise Exception("not running in cpp accelerated mode")
        self.tprint(DEBUG, f"running in api mode: {api}")

        # fmsg_* are exchanged with the frontend, pmsg_* with the proxy.
        self.fmsg_in = facilities_pb2.FacMessage()
        self.fmsg_out = facilities_pb2.FacMessage()
        self.pmsg_in = facilities_pb2.FacMessage()
        self.pmsg_out = facilities_pb2.FacMessage()

        self.storage = {}  # app key -> numeric io_id (file name in FAC_IO_DIR)
        self.io_id = 0  # next file name to hand out
        self.disk_usage = 0  # bytes currently stored
        self.max_disk_usage = int(qt["_furnace_max_disk"])
        self.tprint(DEBUG, f"MAX_DISK_USAGE {self.max_disk_usage}")

        self.context = zmq.Context(io_threads=4)
        self.poller = zmq.Poller()

        # connections with the proxy
        self.tprint(DEBUG, f"loading crypto...")
        self.auth = ThreadAuthenticator(self.context)  # crypto bootstrap
        self.auth.start()
        self.auth.configure_curve(domain="*", location=zmq.auth.CURVE_ALLOW_ANY)
        publisher_public, _ = zmq.auth.load_certificate(kp["be_pub"])
        public, secret = zmq.auth.load_certificate(kp["app_key"])
        assert self.auth.is_alive()
        self.tprint(DEBUG, f"loading crypto... DONE")

        TCP_PXY_RTR = f"tcp://{str(ip)}:{int(it)+0}"
        TCP_PXY_PUB = f"tcp://{str(ip)}:{int(it)+1}"

        # connections to the proxy
        s = f"tcp://{self.ip}:{self.it}"
        self.dealer_pxy = self.context.socket(zmq.DEALER)  # async directed
        self.didentity = f"{u2}-d"  # this must match the FE
        self.dealer_pxy.identity = self.didentity.encode()
        self.dealer_pxy.curve_publickey = public
        self.dealer_pxy.curve_secretkey = secret
        self.dealer_pxy.curve_serverkey = publisher_public
        self.tprint(INFO, f"FAC--PXY: Connecting as Dealer to {TCP_PXY_RTR}")
        self.dealer_pxy.connect(s)

        self.req_pxy = self.context.socket(zmq.REQ)  # sync
        self.ridentity = f"{u2}-r"
        self.req_pxy.identity = self.ridentity.encode()
        self.req_pxy.curve_publickey = public
        self.req_pxy.curve_secretkey = secret
        self.req_pxy.curve_serverkey = publisher_public
        self.req_pxy.connect(s)

        self.sub_pxy = self.context.socket(zmq.SUB)  # async broadcast
        self.sub_pxy.curve_publickey = public
        self.sub_pxy.curve_secretkey = secret
        self.sub_pxy.curve_serverkey = publisher_public
        self.sub_pxy.setsockopt(zmq.SUBSCRIBE, b"")
        self.tprint(INFO, f"FAC--PXY: Connecting as Publisher to {TCP_PXY_PUB}")
        self.sub_pxy.connect(TCP_PXY_PUB)

        # req_pxy is deliberately not polled: it is only read right after a
        # synchronous send in fe2be_process_msg.
        self.poller.register(self.dealer_pxy, zmq.POLLIN)
        self.poller.register(self.sub_pxy, zmq.POLLIN)

        # connections with the frontend
        s = f"ipc://{self.fac_path}{os.sep}fac_rtr"
        self.rtr_fac = self.context.socket(zmq.ROUTER)  # fe sync/async msgs
        self.tprint(INFO, f"APP--FAC: Binding as Router to {s}")
        try:
            self.rtr_fac.bind(s)
        except zmq.error.ZMQError:
            self.tprint(ERROR, f"rtr_fac {s} is already in use!")
            sys.exit()

        s = f"ipc://{self.fac_path}{os.sep}fac_pub"
        self.pub_fac = self.context.socket(zmq.PUB)  # bcast to fe
        self.tprint(INFO, f"APP--FAC: Binding as Publisher to {s}")
        try:
            self.pub_fac.bind(s)
        except zmq.error.ZMQError:
            self.tprint(ERROR, f"pub_fac {s} is already in use!")
            sys.exit()

        self.poller.register(self.rtr_fac, zmq.POLLIN)
        self.tprint(DEBUG, "leaving Facilities constructor")

    def loop(self):
        """
        Main event loop. Polls (with timeout) on ZMQ sockets waiting for data
        from the APP partition and backend.

        Writes a progress character to stdout per event: "v" for traffic from
        the backend, "^" for traffic from the frontend, "." for an idle poll.

        :returns: Nothing (runs until an exception escapes).
        """
        while True:
            socks = dict(self.poller.poll(TIMEOUT_FAC))

            # inbound async from BE (sync should not arrive here)
            if self.dealer_pxy in socks:
                sys.stdout.write("v")
                raw_msg = self.dealer_pxy.recv()
                self.msgin += 1
                self.be2fe_process_protobuf(raw_msg, ASYNC)

            # inbound broadcast from BE
            if self.sub_pxy in socks:
                sys.stdout.write("v")
                raw_msg = self.sub_pxy.recv()
                self.msgin += 1
                self.pmsg_in.ParseFromString(raw_msg)
                self.tprint(DEBUG, "BROADCAST FROM PXY\n%s" % (self.pmsg_in, ))
                # Only the first be_msg of a broadcast is relayed.
                index = 0
                submsg_list = getattr(self.pmsg_in, "be_msg")
                m = submsg_list[index]
                self.pmsg_out.Clear()
                msgnum = self.pmsg_out.__getattribute__("BE_MSG")
                self.pmsg_out.type.append(msgnum)
                submsg = getattr(self.pmsg_out, "be_msg").add()
                submsg.status = m.status
                submsg.value = m.value
                raw_str = self.pmsg_out.SerializeToString()
                # f-prefix was missing: the placeholders were logged literally.
                self.tprint(DEBUG, f"to FE {len(raw_str)}B {self.pmsg_out}")
                self.pub_fac.send(raw_str)
                self.msgout += 1

            # inbound from FE
            if self.rtr_fac in socks:
                sys.stdout.write("^")
                pkt = self.rtr_fac.recv_multipart()
                self.msgin += 1
                if len(pkt) == 3:  # sync: [identity, empty delimiter, payload]
                    ident, empty, raw_msg = pkt
                    raw_out = self.fe2be_process_protobuf(raw_msg, SYNC)
                    # Frames from recv_multipart() and SerializeToString() are
                    # already bytes; calling .encode() on them raises
                    # AttributeError on Python 3.
                    self.rtr_fac.send_multipart([ident, b"", raw_out])
                    self.msgout += 1
                elif len(pkt) == 2:  # async: [identity, payload]
                    ident, raw_msg = pkt
                    self.fe2be_process_protobuf(raw_msg, ASYNC)

            if not socks:
                sys.stdout.write(".")
            sys.stdout.flush()
            self.tick += 1

    def be2fe_process_protobuf(self, raw_msg, sync):
        """
        Processes messages sent by the backend.

        Every BE_MSG in the incoming message is re-wrapped and forwarded to
        the frontend through the router socket, addressed to the dealer
        identity.

        :param raw_msg: Serialized FacMessage received from the backend.
        :param sync: SYNC/ASYNC marker supplied by the caller (not used here).
        :returns: Nothing (modifies class's protobufs).
        """
        self.pmsg_in.ParseFromString(raw_msg)
        self.tprint(DEBUG, "FROM BE\n%s" % (self.pmsg_in, ))
        for t in self.pmsg_in.type:
            msgtype = facilities_pb2.FacMessage.Type.Name(t)  # msgtype = BE_MSG
            msglist = getattr(self.pmsg_in, msgtype.lower())
            m = msglist.pop(0)  # the particular message
            if msgtype == "BE_MSG":
                self.fmsg_out.Clear()
                msgnum = self.fmsg_out.__getattribute__("BE_MSG")
                self.fmsg_out.type.append(msgnum)
                submsg = getattr(self.fmsg_out, "be_msg").add()
                submsg.status = m.status
                submsg.value = m.value
                raw_str = self.fmsg_out.SerializeToString()
                self.tprint(
                    DEBUG,
                    f"to FE {self.didentity} {len(raw_str)}B {self.fmsg_out}")
                self.rtr_fac.send_multipart([self.didentity.encode(), raw_str])
                self.msgout += 1
        return

    # Split off between pass-through to BE or fac-provided IO.
    def fe2be_process_protobuf(self, raw_msg, sync):
        """
        Processes messages sent by the app. Some messages are processed
        locally (IO). Others are forwarded to the backend.

        :param raw_msg: Serialized FacMessage received from the frontend.
        :param sync: SYNC or ASYNC; selects how BE_MSG entries are forwarded.
        :returns: Raw serialized protobuf (the accumulated reply in fmsg_out).
        """
        self.fmsg_in.ParseFromString(raw_msg)
        self.tprint(DEBUG, "FROM FE\n%s" % (self.fmsg_in, ))
        self.fmsg_out.Clear()
        for t in self.fmsg_in.type:
            msgtype = facilities_pb2.FacMessage.Type.Name(t)  # msgtype = BE_MSG
            msglist = getattr(self.fmsg_in, msgtype.lower())
            m = msglist.pop(0)  # the particular message
            if msgtype == "BE_MSG":  # Pass-through to BE
                self.fe2be_process_msg(m, sync)
            elif msgtype == "GET":  # IO
                self.process_get(m)
            elif msgtype == "SET":  # IO
                self.process_set(m)
        self.tprint(DEBUG, f"sending {self.fmsg_out}")
        raw_out = self.fmsg_out.SerializeToString()
        return raw_out

    def fe2be_process_msg(self, m, sync):
        """
        Processes messages sent by the app. All these are destined to the
        backend.

        For SYNC the message goes out on the REQ socket and the backend's
        BE_MSG_RET answer is copied into fmsg_out; for ASYNC it is sent on the
        DEALER socket and nothing is awaited.

        :param m: The be_msg sub-message to forward.
        :param sync: SYNC or ASYNC.
        :returns: Nothing (modifies class's protobufs).
        """
        self.pmsg_out.Clear()
        msgnum = self.pmsg_out.__getattribute__("BE_MSG")
        self.pmsg_out.type.append(msgnum)
        submsg = getattr(self.pmsg_out, "be_msg").add()
        submsg.status = m.status
        submsg.value = m.value
        raw_str = self.pmsg_out.SerializeToString()
        self.tprint(DEBUG,
                    "sending to BE %dB %s" % (len(raw_str), self.pmsg_out))
        if sync == SYNC:
            self.req_pxy.send(raw_str)
            self.msgout += 1
            raw_reply = self.req_pxy.recv()
            self.pmsg_in.ParseFromString(raw_reply)
            self.msgin += 1
            # Log the size of what was received, not of what was sent.
            self.tprint(DEBUG,
                        "recv from BE %dB %s" % (len(raw_reply), self.pmsg_in))
            index = 0
            submsg_list = getattr(self.pmsg_in, "be_msg_ret")
            status = submsg_list[index].status
            val = submsg_list[index].value
            # NOTE(review): this discards anything already accumulated in
            # fmsg_out for the current request — confirm that is intended.
            self.fmsg_out.Clear()
            msgnum = self.fmsg_out.__getattribute__("BE_MSG_RET")
            self.fmsg_out.type.append(msgnum)
            submsg = getattr(self.fmsg_out, "be_msg_ret").add()
            submsg.status = status
            submsg.value = val
            return
        elif sync == ASYNC:
            self.dealer_pxy.send(raw_str)
            self.msgout += 1
            return

    def process_get(self, m):
        """
        The app is asking for previously stored data.

        Appends a GET_RET sub-message to fmsg_out carrying the stored value
        and VMI_SUCCESS, or an empty value and VMI_FAILURE when the key is
        invalid, unknown, or its backing file cannot be read.

        :param m: The protobuf message sent by the app.
        :returns: Nothing (modifies class's protobufs).
        """
        key = str(m.key)
        msgtype = "GET_RET"
        msgnum = self.fmsg_out.__getattribute__(msgtype)
        self.fmsg_out.type.append(msgnum)
        submsg = getattr(self.fmsg_out, msgtype.lower()).add()
        submsg.key = key
        if not self.get_name_ok(key):
            submsg.value = ""
            submsg.result = VMI_FAILURE
            return
        try:
            io_id = self.storage[key]
            with open(FAC_IO_DIR + str(io_id), "rb") as fh:
                submsg.value = fh.read().decode()
            submsg.result = VMI_SUCCESS
        except (KeyError, OSError):
            # self.storage is a dict, so an unknown key raises KeyError (the
            # previous IndexError handler never fired and the loop crashed).
            submsg.value = ""
            submsg.result = VMI_FAILURE
        return

    def process_set(self, m):
        """
        The app has requested we store this data.

        Appends a SET_RET sub-message to fmsg_out. The value is written to a
        new file under FAC_IO_DIR; an existing entry for the same key is
        deleted first. The request fails when the key is invalid, when the
        write would exceed the instance's disk quota, or when file IO fails.

        :param m: The protobuf message sent by the app.
        :returns: Nothing (modifies class's protobufs).
        """
        key = str(m.key)
        value = str(m.value)
        payload = value.encode()  # the bytes that actually hit the disk
        msgtype = "SET_RET"
        msgnum = self.fmsg_out.__getattribute__(msgtype)
        self.fmsg_out.type.append(msgnum)
        submsg = getattr(self.fmsg_out, msgtype.lower()).add()
        submsg.key = key
        if not self.get_name_ok(key):
            self.tprint(ERROR, f"key not allowed")
            submsg.value = ""
            submsg.result = VMI_FAILURE
            return

        # Enforce the per-instance quota taken from the quotas dict in
        # __init__ (it was previously stored but never consulted).
        self.tprint(
            DEBUG,
            f"cur={self.disk_usage}, req={len(payload)}, max={self.max_disk_usage}",
        )
        if self.disk_usage + len(payload) > self.max_disk_usage:
            self.tprint(ERROR, f"disk usage exceeded")
            submsg.result = VMI_FAILURE
            return

        try:
            self.tprint(DEBUG, f"checking for existing key {key}")
            # overwrite a key
            try:
                existing_id = self.storage[key]
                self.tprint(DEBUG, f"found!")
            except KeyError:
                self.tprint(DEBUG, f"not found")
            else:
                s = os.path.getsize(FAC_IO_DIR + str(existing_id))
                self.tprint(DEBUG, f"deleting existing file of size={s}")
                self.disk_usage -= s
                os.remove(FAC_IO_DIR + str(existing_id))
                del self.storage[key]

            io_id = self.io_id
            self.io_id += 1
            self.tprint(DEBUG, f"will assign key {io_id}")
            self.tprint(DEBUG, f"writing file of size {len(payload)}")
            with open(FAC_IO_DIR + str(io_id), "wb") as fh:
                fh.write(payload)
            # Book-keeping happens only after a successful write, and counts
            # bytes so it matches the quota check and os.path.getsize above.
            self.storage[key] = io_id
            self.disk_usage += len(payload)
            self.tprint(DEBUG, f"disk usage at {self.disk_usage}")
            submsg.result = VMI_SUCCESS
        except Exception as exc:  # best effort: report failure to the app
            self.tprint(ERROR, f"set failed: {exc}")
            submsg.result = VMI_FAILURE
        return

    def get_name_ok(self, n):
        """
        :param n: The key (1-16 characters, each a word character or "-").
        :returns: True if in accordance with naming convention, False if otherwise.
        """
        # fullmatch (unlike "$") also rejects a trailing newline.
        return 0 < len(n) < 17 and re.fullmatch(r"[\w-]+", n) is not None

    def check_encoding(self, val):
        """
        Fun. Now with more Unicode.

        :param val: Is this variable a string?
        :returns: Nothing.
        :raises: TypeError if val is not a string.
        """
        if not isinstance(val, str):
            raise TypeError("Error: value is not of type str")

    def shutdown(self):
        """
        Exit cleanly: delete stored IO files, stop the authenticator, close
        the sockets, destroy the context, then shut down the base class.

        A second call only logs an error.

        :returns: Nothing.
        """
        if self.already_shutdown:
            self.tprint(ERROR, "already shutdown?")
            return
        self.already_shutdown = True

        for key, io_id in self.storage.items():
            self.tprint(INFO, "removing io: %s -> %s" % (key, io_id))
            try:
                os.remove(FAC_IO_DIR + str(io_id))
            except OSError as exc:
                # A missing file must not prevent the sockets from closing.
                self.tprint(ERROR, f"could not remove io {io_id}: {exc}")

        self.auth.stop()
        self.poller.unregister(self.dealer_pxy)
        self.poller.unregister(self.sub_pxy)
        self.dealer_pxy.close()
        self.sub_pxy.close()
        self.req_pxy.close()
        self.poller.unregister(self.rtr_fac)
        self.rtr_fac.close()
        self.pub_fac.close()
        self.context.destroy()
        super(Facilities, self).shutdown()
class CombaZMQAdapter(threading.Thread, CombaBase):
    """
    Thread that serves text commands arriving on a ZMQ REP socket.

    Each request is a space-separated string "<command> <param> ..."; the
    command is invoked as a method of a CombaController with the parameters
    passed as strings. The thread starts itself from the constructor.
    """

    # Characters accepted in a command name (identifier characters plus "."
    # for attribute chains). Anything else is reported as a syntax error
    # before it can reach exec.
    _COMMAND_CHARS = frozenset(string.ascii_letters + string.digits + "_.")

    def __init__(self, port):
        """
        Load the configuration and start the server thread.

        :param port: TCP port to bind to (converted to a string)
        """
        self.port = str(port)
        threading.Thread.__init__(self)
        self.shutdown_event = Event()
        # instance() is a classmethod returning the shared context; calling
        # it on the class avoids creating a throw-away Context first.
        self.context = zmq.Context.instance()
        self.authserver = ThreadAuthenticator(self.context)
        self.loadConfig()
        self.start()

    #------------------------------------------------------------------------------------------#
    def run(self):
        """
        Bind the REP socket and process incoming commands until the shutdown
        event is set. Failures of a command are reported through the
        controller's message() instead of ending the loop.
        """
        self.startAuthserver()
        self.data = ''
        self.socket = self.context.socket(zmq.REP)
        self.socket.plain_server = True
        self.socket.bind("tcp://*:" + self.port)
        self.shutdown_event.clear()
        self.controller = CombaController(self, self.lqs_socket, self.lqs_recorder_socket)
        self.controller.messenger.setMailAddresses(self.get('frommail'), self.get('adminmail'))
        self.can_send = False

        # Process tasks forever
        while not self.shutdown_event.is_set():
            self.data = self.socket.recv()
            self.can_send = True
            data = self.data.split(' ')
            command = str(data.pop(0))
            try:
                if not command or not set(command) <= self._COMMAND_CHARS:
                    raise SyntaxError('invalid command name')
                # SECURITY: this executes source assembled from network input.
                # The command name is restricted above and every parameter is
                # emitted through repr(), so a parameter can no longer close
                # its quotes and inject code. Any public attribute chain of
                # the controller is still callable by a client.
                params = "(" + ",".join([repr(p) for p in data]) + ")"
                exec("a=self.controller." + command + params, {"self": self})
            except SyntaxError:
                self.controller.message('Warning: Syntax Error')
            except AttributeError:
                print("Warning: Method " + command + " does not exist")
                self.controller.message('Warning: Method ' + command + ' does not exist')
            except TypeError:
                print("Warning: Wrong number of params")
                self.controller.message('Warning: Wrong number of params')
            except Exception:
                print("Warning: Unknown Error")
                self.controller.message('Warning: Unknown Error')
        return

    #------------------------------------------------------------------------------------------#
    def halt(self):
        """
        Stop the server: drop the controller, set the shutdown event and
        unbind the socket. Does nothing if the event is already set.
        """
        if self.shutdown_event.is_set():
            return
        try:
            del self.controller
        except AttributeError:
            # run() has not created a controller (yet)
            pass
        self.shutdown_event.set()
        try:
            self.socket.unbind("tcp://*:" + self.port)
        except Exception:
            # best effort: the socket may be missing or not bound
            pass
        #self.socket.close()

    #------------------------------------------------------------------------------------------#
    def reload(self):
        """
        Stop, reload the config and start again. Note that run() is called
        directly, i.e. in the thread of the caller.
        """
        if self.shutdown_event.is_set():
            return
        self.loadConfig()
        self.halt()
        time.sleep(3)
        self.run()

    #------------------------------------------------------------------------------------------#
    def send(self, message):
        """
        Send a message to the client. Only one reply is sent per received
        request; further calls are ignored until the next request arrives.
        :param message: string
        """
        if self.can_send:
            self.socket.send(message)
            self.can_send = False

    #------------------------------------------------------------------------------------------#
    def startAuthserver(self):
        """
        Start zmq authentification server according to self.securitylevel:
        0 starts nothing, > 0 enables PLAIN user/password authentication,
        > 1 additionally loads the address whitelist.
        """
        # stop auth server if running
        if self.authserver.is_alive():
            self.authserver.stop()
        if self.securitylevel > 0:
            # Start the authentication server.
            self.authserver.start()

            # At security level 2 also allow only the whitelisted addresses
            # (user name and password are required by the PLAIN setup below).
            if self.securitylevel > 1:
                try:
                    addresses = CombaWhitelist().getList()
                    for address in addresses:
                        self.authserver.allow(address)
                except Exception:
                    # best effort, but do not hide that the whitelist is off
                    print("Warning: could not load address whitelist")

            # Instruct authenticator to handle PLAIN requests
            self.authserver.configure_plain(domain='*', passwords=self.getAccounts())

    #------------------------------------------------------------------------------------------#
    def getAccounts(self):
        """
        Get accounts from redis db. The logins from CombaUser are extended by
        an internal account which is generated and stored under the redis key
        'internAccess' on first use.
        :return: the accounts object returned by CombaUser().getLogins() with
                 the internal login added (indexed by user name, value is the
                 password)
        """
        accounts = CombaUser().getLogins()
        db = redis.Redis()
        internaccount = db.get('internAccess')
        if not internaccount:
            # Credentials must come from the OS entropy source, not from the
            # predictable default Mersenne Twister generator.
            rng = random.SystemRandom()
            user = ''.join(rng.sample(string.ascii_lowercase, 10))
            password = ''.join(rng.sample(string.ascii_lowercase + string.ascii_uppercase + string.digits, 22))
            db.set('internAccess', user + ':' + password)
            intern = [user, password]
        else:
            intern = internaccount.split(':', 1)
        accounts[intern[0]] = intern[1]
        return accounts
class StupidNode:
    """A ZMQ node with PUB, ROUTER and cleartext REP sockets bound locally.

    For every remote endpoint it connects to, the node keeps one SUB and one
    DEALER socket (parallel lists ``sub``, ``dealer`` and ``endpoints`` share
    the same index). PUB/ROUTER/SUB/DEALER use CURVE; the REP socket runs in
    a separate cleartext context and answers "who are you" requests with the
    node's identity and public key so that peers can learn its certificate.

    Subclasses customise behaviour by overriding the ``*_react`` hooks; a
    hook returning a falsy value stops the remaining hooks of a workflow.
    """

    pubkey = privkey = None
    channel = ""  # subscription filter or something (I think)
    PORTS = 4  # as we add or remove ports, make sure this is the number of ports a StupidNode uses

    def __init__(self, endpoint="*", identity=None, keyring=DEFAULT_KEYRING):
        """Create contexts, keys and sockets, bind them and start the reply thread.

        :param endpoint: an Endpoint, or anything accepted by Endpoint()
        :param identity: node name; defaults to "<hostname>-<endpoint.pub>"
        :param keyring: directory holding the certificate files
        """
        self.keyring = keyring
        self.endpoint = (endpoint if isinstance(endpoint, Endpoint) else
                         Endpoint(endpoint))
        self.endpoints = list()
        self.identity = identity or f"{gethostname()}-{self.endpoint.pub}"
        self.log = logging.getLogger(f"{self.identity}")

        self.log.debug("begin node setup / creating context")
        self.ctx = zmq.Context()
        self.cleartext_ctx = zmq.Context()
        self.start_auth()

        self.log.debug("creating sockets")
        self.pub = self.mk_socket(zmq.PUB)
        self.router = self.mk_socket(zmq.ROUTER)
        self.router.router_mandatory = (
            1  # one of the few opts that can be set after bind()
        )
        self.rep = self.mk_socket(zmq.REP, enable_curve=False)
        self.sub = list()
        self.dealer = list()

        self.log.debug("binding sockets")
        self.bind(self.pub)
        self.bind(self.router)
        self.bind(self.rep, enable_curve=False)

        self.log.debug("registering polling")
        self.poller = zmq.Poller()
        self.poller.register(self.router, zmq.POLLIN)

        self.log.debug("configuring interrupt signal")
        signal.signal(signal.SIGINT, self.interrupt)

        self.log.debug("configuring WAI Reply Thread")
        self._who_are_you_thread = Thread(
            target=self.who_are_you_reply_machine)
        self._who_are_you_continue = True
        self._who_are_you_thread.start()

        self.route_queue = deque(list(), ROUTE_QUEUE_LEN)
        self.routes = dict()
        self.log.debug("node setup complete")

    def who_are_you_reply_machine(self):
        """Thread body: answer every request on the REP socket with
        [identity, pubkey] until ``_who_are_you_continue`` is cleared."""
        while self._who_are_you_continue:
            # short poll so the loop notices the stop flag
            if self.rep.poll(200):
                self.log.debug("wai polled, trying to recv")
                msg = self.rep.recv()
                ttype = zmq_socket_type_name(self.rep)
                self.log.debug('received "%s" over %s socket', msg, ttype)
                msg = [self.identity.encode(), self.pubkey]
                self.log.debug('sending "%s" as reply over %s socket', msg,
                               ttype)
                self.rep.send_multipart(msg)
        self.log.debug("wai thread seems finished, loop broken")

    def start_auth(self):
        """Start the CURVE authenticator on the crypto context (certificates
        from the keyring, 127.0.0.1 allowed) and load or create our key."""
        self.log.debug("starting auth thread")
        self.auth = ThreadAuthenticator(self.ctx)
        self.auth.start()
        self.auth.allow("127.0.0.1")
        self.auth.configure_curve(domain="*", location=self.keyring)
        self.load_or_create_key()

    @property
    def key_basename(self):
        """Certificate base name derived from the node identity."""
        return scrub_identity_name_for_certfile(self.identity)

    @property
    def key_filename(self):
        """Path of the public certificate file inside the keyring."""
        return os.path.join(self.keyring, self.key_basename + ".key")

    @property
    def secret_key_filename(self):
        """Path of the secret certificate file inside the keyring."""
        return self.key_filename + "_secret"

    def load_key(self):
        """Load pubkey/privkey from the secret certificate file."""
        self.log.debug("loading node key-pair")
        self.pubkey, self.privkey = zmq.auth.load_certificate(
            self.secret_key_filename)

    def load_or_create_key(self):
        """Load the node key-pair, creating the certificates (and the
        keyring directory) first if loading fails with IOError."""
        try:
            self.load_key()
        except IOError as e:
            self.log.debug("error loading key: %s", e)
            self.log.debug("creating node key-pair")
            os.makedirs(self.keyring, mode=0o0700, exist_ok=True)
            zmq.auth.create_certificates(self.keyring, self.key_basename)
            self.load_key()

    def preprocess_message(self, msg, msg_class=TaggedMessage):
        """Coerce *msg* into *msg_class* (tagged with our identity).

        :returns: (message object, its repr, its encoded form)
        """
        if not isinstance(msg, msg_class):
            if not isinstance(msg, (list, tuple)):
                msg = (msg, )
            msg = msg_class(*msg, name=self.identity)
        rmsg = repr(msg)
        emsg = msg.encode()
        return msg, rmsg, emsg

    def route_failed(self, msg):
        """Count a delivery failure; requeue the message for up to 5
        failures, discard it afterwards.

        :raises TypeError: if msg is not a RoutedMessage
        """
        if not isinstance(msg, RoutedMessage):
            raise TypeError("msg must already be a RoutedMessage")
        msg.failures += 1
        if msg.failures <= 5:
            self.log.debug("(re)queueing %s for later delivery", repr(msg))
            # the deque is bounded: appending to a full queue drops its head
            if len(self.route_queue) == self.route_queue.maxlen:
                self.log.error("route_queue full, discarding %s",
                               repr(self.route_queue[0]))
            self.route_queue.append(msg)
        else:
            self.log.error("discarding %s after %d failures", repr(msg),
                           msg.failures)

    def route_message(self, to, msg):
        """Send *msg* to node *to* over the ROUTER socket.

        *to* may be a StupidNode, an identity, or a list/tuple whose last
        element is the target. If ``routes`` has an entry for the target,
        its first element is used as the next hop. A "Host unreachable"
        send error hands the message to route_failed(); other ZMQ errors
        propagate.
        """
        if isinstance(to, StupidNode):
            to = to.identity
        if isinstance(to, (list, tuple)):
            to = to[-1]
        R = self.routes.get(to)
        if R:
            to = (R[0], to)
        if isinstance(msg, RoutedMessage):
            msg.to = to
        else:
            # preprocess passes *msg to msg_class() -- ie, RoutedMessage(to, *msg)
            if isinstance(msg, list):
                msg = tuple(msg)
            elif not isinstance(msg, tuple):
                msg = (msg, )
            msg = (to, ) + msg
        tmsg, rmsg, emsg = self.preprocess_message(msg,
                                                   msg_class=RoutedMessage)
        self.log.debug("routing message %s -- encoding: %s", rmsg, emsg)
        try:
            self.router.send_multipart(emsg)
        except zmq.error.ZMQError as zmq_e:
            self.log.debug("route to %s failed: %s", to, zmq_e)
            if "Host unreachable" not in str(zmq_e):
                raise
            self.route_failed(tmsg)

    def deal_message(self, msg):
        """Send *msg* on all DEALER sockets without publishing it."""
        self.log.debug(
            "dealing message (actually publishing with no_publish=True)")
        self.publish_message(msg, no_publish=True)

    def publish_message(self,
                        msg,
                        no_deal=False,
                        no_deal_to=None,
                        no_publish=False):
        """Run the local workflow for *msg*, publish it and deal it.

        :param no_deal: skip the DEALER sockets entirely
        :param no_deal_to: dealers to skip: a dealer index, a dealer socket,
            a collection of indexes, or a callable ``f(index) -> bool``
            returning True for dealers that should receive the message
        :param no_publish: skip the PUB socket
        :raises TypeError: if no_deal_to has an unsupported type
        """
        tmsg, rmsg, emsg = self.preprocess_message(msg)
        self.log.debug(
            "publishing message %s no_publish=%s, no_deal=%s, no_deal_to=%s",
            rmsg,
            no_publish,
            no_deal,
            no_deal_to,
        )
        self.local_workflow(tmsg)
        if not no_publish:
            self.pub.send_multipart(emsg)
        if no_deal:
            return

        if no_deal_to is None:
            def ok_send(x):
                return True
        elif callable(no_deal_to):
            ok_send = no_deal_to
        elif isinstance(no_deal_to, zmq.Socket):
            npt_i = self.dealer.index(no_deal_to)

            def ok_send(x):
                return x != npt_i
        elif isinstance(no_deal_to, int):
            def ok_send(x):
                return x != no_deal_to
        elif isinstance(no_deal_to, (list, tuple, set, frozenset)):
            def ok_send(x):
                return x not in no_deal_to
        else:
            # previously fell through and failed later with an unbound name
            raise TypeError(
                f"unsupported no_deal_to type: {type(no_deal_to).__name__}")

        for i, sock in enumerate(self.dealer):
            if ok_send(i):
                self.log.debug("dealing message %s to %s", rmsg,
                               self.endpoints[i])
                sock.send_multipart(emsg)
            else:
                self.log.debug("not sending %s to %s", rmsg,
                               self.endpoints[i])

    def mk_socket(self, stype, enable_curve=True):
        """Create a socket of type *stype* in the crypto context (with our
        CURVE key-pair) or, with enable_curve=False, in the cleartext one."""
        # defaults:
        # socket.setsockopt(zmq.LINGER, -1) # infinite
        # socket.setsockopt(zmq.IDENTITY, None)
        # socket.setsockopt(zmq.TCP_KEEPALIVE, -1)
        # socket.setsockopt(zmq.TCP_KEEPALIVE_INTVL, -1)
        # socket.setsockopt(zmq.TCP_KEEPALIVE_CNT, -1)
        # socket.setsockopt(zmq.TCP_KEEPALIVE_IDLE, -1)
        # socket.setsockopt(zmq.RECONNECT_IVL, 100)
        # socket.setsockopt(zmq.RECONNECT_IVL_MAX, 0) # 0 := always use IVL
        # the above can be accessed as attributes instead (they are case
        # insensitive, we choose lower case below so it looks like boring
        # python)
        if enable_curve:
            socket = self.ctx.socket(stype)
            self.log.debug("create %s socket in crypto context",
                           zmq_socket_type_name(stype))
        else:
            socket = self.cleartext_ctx.socket(stype)
            self.log.debug("create %s socket in cleartext context",
                           zmq_socket_type_name(stype))
        socket.linger = 1
        socket.identity = self.identity.encode()
        socket.reconnect_ivl = 1000
        socket.reconnect_ivl_max = 10000
        if enable_curve:
            socket.curve_secretkey = self.privkey
            socket.curve_publickey = self.pubkey
        return socket

    def local_workflow(self, msg):
        """Pass a locally originated message through local_react and, if it
        survives, all_react."""
        self.log.debug("start local_workflow %s", repr(msg))
        msg = self.local_react(msg)
        if msg:
            msg = self.all_react(msg)
        return msg

    def sub_workflow(self, socket):
        """Receive from a SUB socket and run sub/nonlocal/all reactions."""
        idx = self.sub.index(socket)
        enp = self.endpoints[idx]
        msg = self.sub_receive(socket, idx)
        self.log.debug("start sub_workflow (idx=%d -> endpoint=%s) %s", idx,
                       enp, repr(msg))
        for react in (self.sub_react, self.nonlocal_react, self.all_react):
            if msg:
                msg = react(msg, idx=idx)
        self.log.debug("end sub_workflow")
        return msg

    def router_workflow(self):
        """Receive from the ROUTER socket and run router/nonlocal/all
        reactions."""
        msg = self.router_receive()
        self.log.debug("start router_workflow %s", repr(msg))
        for react in (self.router_react, self.nonlocal_react, self.all_react):
            if not msg:
                break
            msg = react(msg)
        self.log.debug("end router_workflow")
        return msg

    def dealer_workflow(self, socket):
        """Receive from a DEALER socket and run dealer/nonlocal/all
        reactions."""
        idx = self.dealer.index(socket)
        enp = self.endpoints[idx]
        msg = self.dealer_receive(socket, idx)
        self.log.debug("start deal_workflow (idx=%d -> endpoint=%s) %s", idx,
                       enp, repr(msg))
        for react in (self.dealer_react, self.nonlocal_react, self.all_react):
            if not msg:
                break
            msg = react(msg, idx=idx)
        self.log.debug("end deal_workflow")
        return msg

    def sub_receive(self, socket, idx):  # pylint: disable=unused-argument
        """Read one multipart message from a SUB socket as a TaggedMessage."""
        return TaggedMessage(*socket.recv_multipart())

    def dealer_receive(self, socket, idx):  # pylint: disable=unused-argument
        """Read one multipart message from a DEALER socket."""
        msg = socket.recv_multipart()
        rm = RoutedMessage.decode(msg)
        if rm:
            return rm
        # dealer's always receive a routed message if it doesn't appear to be
        # routed, then it's simply intended for us. In that case, build a
        # tagged message and mark it as non-publish
        msg = TaggedMessage(*msg)
        msg.publish_mark = False
        return msg

    def router_receive(self):
        """Read one multipart message from the ROUTER socket."""
        # we ignore the source ID (in '_') and just believe the msg.tag.name ... it's
        # roughly the same thing anyway
        _, *msg = self.router.recv_multipart()
        rm = RoutedMessage.decode(msg)
        if rm:
            return rm
        return TaggedMessage(*msg)

    def all_react(self, msg, idx=None):  # pylint: disable=unused-argument
        """Hook run last in every workflow; returns msg unchanged."""
        return msg

    def sub_react(self, msg, idx=None):  # pylint: disable=unused-argument
        """Hook for messages received on a SUB socket; returns msg unchanged."""
        return msg

    def dealer_react(self, msg, idx=None):  # pylint: disable=unused-argument
        """Hook for messages received on a DEALER socket; returns msg unchanged."""
        return msg

    def router_react(self, msg):
        """Hook for messages received on the ROUTER socket; returns msg unchanged."""
        return msg

    def nonlocal_react(self, msg, idx=None):
        """Hook for messages from other nodes: RoutedMessages are handed to
        routed_react, everything else passes through."""
        if isinstance(msg, RoutedMessage):
            msg = self.routed_react(msg, idx=idx)
        return msg

    def local_react(self, msg):
        """Hook for locally originated messages; returns msg unchanged."""
        return msg

    def routed_react(self, msg, idx=None):  # pylint: disable=unused-argument
        """Hook for RoutedMessages; the base implementation returns False,
        which ends the workflow for that message."""
        return False

    def poll(self, timeo=500, other_cb=None):
        """Check to see if there's any incoming messages. If anything seems
        ready to receive, invoke the related workflow or invoke other_cb (if
        given) on the socket item.

        :param timeo: poll timeout, passed to the poller
        :param other_cb: callable invoked with any ready item that is not one
            of our SUB/DEALER/ROUTER sockets
        :returns: list of the TaggedMessage results of the workflows
        :raises Exception: for a ready item nobody handles
        """
        items = dict(self.poller.poll(timeo))
        ret = list()
        for item in items:
            if items[item] != zmq.POLLIN:
                continue
            if item in self.sub:
                res = self.sub_workflow(item)
            elif item in self.dealer:
                res = self.dealer_workflow(item)
            elif item is self.router:
                res = self.router_workflow()
            elif callable(other_cb):
                res = other_cb(item)
            else:
                self.log.error(
                    "no workflow defined for socket of type %s -- regarding as fatal",
                    zmq_socket_type_name(item),
                )
                # note: this normally doesn't trigger an exit... thanks threading
                raise Exception("unhandled poll item")
            if isinstance(res, TaggedMessage):
                ret.append(res)
        return ret

    def interrupt(self, signo, eframe):  # pylint: disable=unused-argument
        """SIGINT handler: release everything and exit with status 0."""
        print(" kaboom")
        self.closekill()
        sys.exit(0)

    def closekill(self):
        """Stop the auth thread, join the reply thread and destroy both
        contexts. Each resource is deleted once released, so the method is
        safe to call repeatedly."""
        if hasattr(self, "auth") and self.auth is not None:
            if self.auth.is_alive():
                self.log.debug("trying to stop auth thread")
                self.auth.stop()
                self.log.debug("auth thread seems to have stopped")
            del self.auth
        if hasattr(self, "_who_are_you_thread"):
            if self._who_are_you_thread.is_alive():
                self.log.debug("WAI Thread seems to be alive, trying to join")
                self._who_are_you_continue = False
                self._who_are_you_thread.join()
                self.log.debug("WAI Thread seems to jave joined us.")
            del self._who_are_you_thread
        if hasattr(self, "cleartext_ctx"):
            self.log.debug("destroying cleartext context")
            self.cleartext_ctx.destroy(1)
            del self.cleartext_ctx
        if hasattr(self, "ctx"):
            self.log.debug("destroying crypto context")
            self.ctx.destroy(1)
            del self.ctx

    def __del__(self):
        self.log.debug("%s is being deleted", self)
        self.closekill()

    def bind(self, socket, enable_curve=True):
        """Bind *socket* to our endpoint's address for its socket type.

        :raises zmq.ZMQError: if binding fails (message names the address)
        """
        if enable_curve:
            socket.curve_server = True  # must come before bind
        # formatted outside the try block so the error message below can
        # always refer to it
        f = self.endpoint.format(socket.type)
        try:
            socket.bind(f)
        except zmq.ZMQError as e:
            raise zmq.ZMQError(f"unable to bind {f}: {e}") from e

    def who_are_you_request(self, endpoint):
        """Ask *endpoint* over a cleartext REQ socket for its identity and
        public key. Blocks until a reply arrives.

        :returns: the two reply frames [node_id, public_key], or
            (None, None) if the reply does not have exactly two frames
        """
        req = self.mk_socket(zmq.REQ, enable_curve=False)
        try:
            req.connect(endpoint.format(zmq.REQ))
            msg = b"Who are you?"
            self.log.debug("sending cleartext request: %s", msg)
            req.send(msg)
            self.log.debug("waiting for reply")
            res = req.recv_multipart()
            self.log.debug("received reply: %s", res)
            if len(res) == 2:
                return res
            return None, None
        finally:
            # the socket used to leak whenever a valid reply was returned
            req.close()

    def pubkey_pathname(self, node_id):
        """Keyring path of the public certificate for *node_id*, which may
        be an identity or an Endpoint (its host is used then)."""
        if isinstance(node_id, Endpoint):
            # use this endpoint's host, not the Endpoint class attribute
            node_id = node_id.host
        fname = scrub_identity_name_for_certfile(node_id) + ".key"
        pname = os.path.join(self.keyring, fname)
        return pname

    def learn_or_load_endpoint_pubkey(self, endpoint):
        """Return the public key of *endpoint*.

        If no certificate file exists for the endpoint, the key is requested
        from the remote node and written to the keyring under the identity
        the node reports (also stored in ``endpoint.identity``).
        """
        epubk_pname = self.pubkey_pathname(endpoint)
        if not os.path.isfile(epubk_pname):
            self.log.debug(
                "%s does not exist yet, trying to learn certificate",
                epubk_pname)
            node_id, public_key = self.who_are_you_request(endpoint)
            if node_id:
                endpoint.identity = node_id.decode()
                # pass the decoded identity: the path helper is used with
                # text identities everywhere else in this class
                epubk_pname = self.pubkey_pathname(endpoint.identity)
                if not os.path.isfile(epubk_pname):
                    with open(epubk_pname, "wb") as fh:
                        fh.write(
                            b"# generated via rep/req pubkey transfer\n\n")
                        fh.write(b"metadata\n")
                        # NOTE: in zmq/auth/certs.py's _write_key_file,
                        # metadata should be key-value pairs; roughly like the
                        # following (although with their particular py2/py3
                        # nerosis edited out):
                        #
                        # f.write('metadata\n')
                        # for k,v in metadata.items():
                        #     f.write(f"    {k} = {v}\n")
                        fh.write(b"curve\n")
                        fh.write(b'    public-key = "')
                        fh.write(public_key)
                        fh.write(b'"')
        self.log.debug("loading certificate %s", epubk_pname)
        ret, _ = zmq.auth.load_certificate(epubk_pname)
        return ret

    def connect_to_endpoints(self, *endpoints):
        """Connect to every given endpoint; returns self for chaining."""
        self.log.debug("connecting remote endpoints")
        for item in endpoints:
            self.connect_to_endpoint(item)
        self.log.debug("remote endpoints connected")
        return self

    def _create_connected_socket(self,
                                 endpoint,
                                 stype,
                                 pubkey,
                                 preconnect=None):
        """Create a CURVE client socket of type *stype* for *endpoint*.

        :param pubkey: the server's public key
        :param preconnect: optional callable applied to the socket before
            connect()
        """
        self.log.debug("creating %s socket to endpoint=%s",
                       zmq_socket_type_name(stype), endpoint)
        s = self.mk_socket(stype)
        s.curve_serverkey = pubkey
        if callable(preconnect):
            preconnect(s)
        s.connect(endpoint.format(stype))
        return s

    def connect_to_endpoint(self, endpoint):
        """Connect SUB and DEALER sockets to *endpoint* (a StupidNode, an
        Endpoint or anything accepted by Endpoint()), register them with the
        poller and record them; returns self for chaining."""
        if isinstance(endpoint, StupidNode):
            endpoint = endpoint.endpoint
        elif not isinstance(endpoint, Endpoint):
            endpoint = Endpoint(endpoint)
        self.log.debug("learning or loading endpoint=%s pubkey", endpoint)
        epk = self.learn_or_load_endpoint_pubkey(endpoint)

        def subscribe(s):
            s.setsockopt_string(zmq.SUBSCRIBE, self.channel)

        sub = self._create_connected_socket(endpoint, zmq.SUB, epk, subscribe)
        self.poller.register(sub, zmq.POLLIN)
        self.sub.append(sub)

        deal = self._create_connected_socket(endpoint, zmq.DEALER, epk)
        self.poller.register(deal, zmq.POLLIN)
        self.dealer.append(deal)

        # appended last so sub/dealer/endpoints stay index-aligned
        self.endpoints.append(endpoint)
        return self

    def __repr__(self):
        return f"{self.__class__.__name__}({self.identity})"