class AnsibleKernelHelpersThread(object):
    """Background IOLoop thread that relays two ZMQ sockets onto a queue.

    A REP "pause" socket and a PULL "status" socket are bound to random ports
    on 127.0.0.1; the chosen ports are exposed as ``pause_socket_port`` and
    ``status_socket_port``. Every message received on either socket has its
    first frame JSON-decoded, wrapped in ``TaskCompletionMessage`` (pause) or
    ``StatusMessage`` (status), and put on ``queue``.
    """

    def __init__(self, queue):
        self.queue = queue
        # Private loop: it is made current only inside the worker thread.
        self.io_loop = IOLoop(make_current=False)

        ctx = zmq.Context.instance()
        loopback = "tcp://127.0.0.1"

        # NOTE(review): a REP socket must reply before it can receive again;
        # no reply is sent in this class, so presumably the owner replies via
        # ``pause_socket`` -- confirm against callers.
        self.pause_socket = ctx.socket(zmq.REP)
        self.pause_socket_port = self.pause_socket.bind_to_random_port(
            loopback)

        self.status_socket = ctx.socket(zmq.PULL)
        self.status_socket_port = self.status_socket.bind_to_random_port(
            loopback)

        # Wrap both sockets in streams driven by our private loop.
        self.pause_stream = ZMQStream(self.pause_socket, self.io_loop)
        self.status_stream = ZMQStream(self.status_socket, self.io_loop)
        self.pause_stream.on_recv(self.recv_pause)
        self.status_stream.on_recv(self.recv_status)

        worker = threading.Thread(target=self._thread_main)
        worker.daemon = True
        self.thread = worker

    def start(self):
        """Start the worker thread and register ``stop`` to run at exit."""
        logger.info('thread.start')
        self.thread.start()
        atexit.register(self.stop)

    def stop(self):
        """Ask the loop to stop from its own thread and wait for the worker.

        Safe to call more than once; it does nothing once the thread is dead.
        """
        logger.info('thread.stop start')
        if self.thread.is_alive():
            # add_callback is the thread-safe way to schedule work on the loop.
            self.io_loop.add_callback(self.io_loop.stop)
            self.thread.join()
            logger.info('thread.stop end')

    def recv_status(self, msg):
        """Decode a status message (first frame is JSON) and enqueue it."""
        log = logging.getLogger('ansible_kernel.kernel.recv_status')
        log.info(msg)
        payload = json.loads(msg[0])
        self.queue.put(StatusMessage(payload))

    def recv_pause(self, msg):
        """Decode a task-completion message (first frame is JSON) and enqueue it."""
        log = logging.getLogger('ansible_kernel.kernel.recv_pause')
        log.info("completed %s waiting...", msg)
        payload = json.loads(msg[0])
        self.queue.put(TaskCompletionMessage(payload))

    def _thread_main(self):
        """Worker body: adopt the loop, run it until stopped, then close it
        together with every fd registered on it."""
        loop = self.io_loop
        loop.make_current()
        loop.start()
        loop.close(all_fds=True)
class TestFutureSocket(BaseZMQTestCase):
    """Tests for the Future-returning socket API (``future.Context``).

    Each test runs its coroutine body on a fresh Tornado IOLoop created in
    ``setUp`` and closed (with all registered fds) in ``tearDown``.
    """
    # Makes self.context / self.socket() produce future-returning sockets.
    Context = future.Context

    def setUp(self):
        """Create a fresh IOLoop and make it current before base setup."""
        self.loop = IOLoop()
        self.loop.make_current()
        super(TestFutureSocket, self).setUp()

    def tearDown(self):
        """Run base teardown, then close the loop and every fd registered on it."""
        super(TestFutureSocket, self).tearDown()
        self.loop.close(all_fds=True)

    def test_socket_class(self):
        """Sockets from the future Context are ``future.Socket`` instances."""
        s = self.context.socket(zmq.PUSH)
        assert isinstance(s, future.Socket)
        s.close()

    def test_recv_multipart(self):
        """A pending recv_multipart future resolves once a message is sent."""
        @gen.coroutine
        def test():
            a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
            # Request the receive before anything has been sent.
            f = b.recv_multipart()
            assert not f.done()
            yield a.send(b'hi')
            recvd = yield f
            self.assertEqual(recvd, [b'hi'])
        self.loop.run_sync(test)

    def test_recv(self):
        """Queued recv() futures each get one frame, in request order."""
        @gen.coroutine
        def test():
            a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
            f1 = b.recv()
            f2 = b.recv()
            assert not f1.done()
            assert not f2.done()
            yield a.send_multipart([b'hi', b'there'])
            recvd = yield f2
            # f2 finishing implies f1 (queued first) already finished.
            assert f1.done()
            self.assertEqual(f1.result(), b'hi')
            self.assertEqual(recvd, b'there')
        self.loop.run_sync(test)

    def test_recv_cancel(self):
        """A cancelled recv does not consume the message meant for the next one."""
        @gen.coroutine
        def test():
            a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
            f1 = b.recv()
            f2 = b.recv_multipart()
            assert f1.cancel()
            assert f1.done()
            assert not f2.done()
            yield a.send_multipart([b'hi', b'there'])
            recvd = yield f2
            assert f1.cancelled()
            assert f2.done()
            # The whole multipart message went to f2.
            self.assertEqual(recvd, [b'hi', b'there'])
        self.loop.run_sync(test)

    def test_poll(self):
        """poll(): timeout=0 resolves at once, a short timeout resolves to 0,
        and a pending poll resolves to POLLIN once data arrives."""
        @gen.coroutine
        def test():
            a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
            f = b.poll(timeout=0)
            self.assertEqual(f.result(), 0)
            f = b.poll(timeout=1)
            assert not f.done()
            evt = yield f
            self.assertEqual(evt, 0)
            f = b.poll(timeout=1000)
            assert not f.done()
            yield a.send_multipart([b'hi', b'there'])
            evt = yield f
            self.assertEqual(evt, zmq.POLLIN)
            # Polling must not have consumed the message.
            recvd = yield b.recv_multipart()
            self.assertEqual(recvd, [b'hi', b'there'])
        self.loop.run_sync(test)
class TestFutureSocket(BaseZMQTestCase):
    """Tests for the Future-returning socket API (``future.Context``).

    Each test runs its coroutine body on a fresh Tornado IOLoop created in
    ``setUp`` and closed (with all registered fds) in ``tearDown``.
    """
    # Makes self.context / self.socket() produce future-returning sockets.
    Context = future.Context

    def setUp(self):
        """Create a fresh IOLoop and make it current before base setup."""
        self.loop = IOLoop()
        self.loop.make_current()
        super(TestFutureSocket, self).setUp()

    def tearDown(self):
        """Run base teardown, then close the loop and every fd registered on it."""
        super(TestFutureSocket, self).tearDown()
        self.loop.close(all_fds=True)

    def test_socket_class(self):
        """Sockets from the future Context are ``future.Socket`` instances."""
        s = self.context.socket(zmq.PUSH)
        assert isinstance(s, future.Socket)
        s.close()

    def test_recv_multipart(self):
        """A pending recv_multipart future resolves once a message is sent."""
        @gen.coroutine
        def test():
            a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
            # Request the receive before anything has been sent.
            f = b.recv_multipart()
            assert not f.done()
            yield a.send(b'hi')
            recvd = yield f
            self.assertEqual(recvd, [b'hi'])
        self.loop.run_sync(test)

    def test_recv(self):
        """Queued recv() futures each get one frame, in request order."""
        @gen.coroutine
        def test():
            a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
            f1 = b.recv()
            f2 = b.recv()
            assert not f1.done()
            assert not f2.done()
            yield a.send_multipart([b'hi', b'there'])
            recvd = yield f2
            # f2 finishing implies f1 (queued first) already finished.
            assert f1.done()
            self.assertEqual(f1.result(), b'hi')
            self.assertEqual(recvd, b'there')
        self.loop.run_sync(test)

    def test_recv_cancel(self):
        """A cancelled recv does not consume the message meant for the next one."""
        @gen.coroutine
        def test():
            a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
            f1 = b.recv()
            f2 = b.recv_multipart()
            assert f1.cancel()
            assert f1.done()
            assert not f2.done()
            yield a.send_multipart([b'hi', b'there'])
            recvd = yield f2
            assert f1.cancelled()
            assert f2.done()
            # The whole multipart message went to f2.
            self.assertEqual(recvd, [b'hi', b'there'])
        self.loop.run_sync(test)

    @pytest.mark.skipif(not hasattr(zmq, 'RCVTIMEO'), reason="requires RCVTIMEO")
    def test_recv_timeout(self):
        """rcvtimeo is captured per request: f1 (100ms) times out with
        zmq.Again, while f2 (1000ms) still receives the later message."""
        @gen.coroutine
        def test():
            a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
            b.rcvtimeo = 100
            f1 = b.recv()
            b.rcvtimeo = 1000
            f2 = b.recv_multipart()
            with pytest.raises(zmq.Again):
                yield f1
            yield a.send_multipart([b'hi', b'there'])
            recvd = yield f2
            assert f2.done()
            self.assertEqual(recvd, [b'hi', b'there'])
        self.loop.run_sync(test)

    @pytest.mark.skipif(not hasattr(zmq, 'SNDTIMEO'), reason="requires SNDTIMEO")
    def test_send_timeout(self):
        """A send on an unconnected PUSH socket fails with zmq.Again after sndtimeo."""
        @gen.coroutine
        def test():
            s = self.socket(zmq.PUSH)
            s.sndtimeo = 100
            with pytest.raises(zmq.Again):
                yield s.send(b'not going anywhere')
        self.loop.run_sync(test)

    def test_recv_string(self):
        """recv_string round-trips a non-ASCII unicode string."""
        @gen.coroutine
        def test():
            a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
            f = b.recv_string()
            assert not f.done()
            msg = u('πøøπ')
            yield a.send_string(msg)
            recvd = yield f
            assert f.done()
            self.assertEqual(f.result(), msg)
            self.assertEqual(recvd, msg)
        self.loop.run_sync(test)

    def test_recv_json(self):
        """recv_json round-trips a dict sent with send_json."""
        @gen.coroutine
        def test():
            a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
            f = b.recv_json()
            assert not f.done()
            obj = dict(a=5)
            yield a.send_json(obj)
            recvd = yield f
            assert f.done()
            self.assertEqual(f.result(), obj)
            self.assertEqual(recvd, obj)
        self.loop.run_sync(test)

    def test_recv_pyobj(self):
        """recv_pyobj round-trips a dict sent with send_pyobj."""
        @gen.coroutine
        def test():
            a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
            f = b.recv_pyobj()
            assert not f.done()
            obj = dict(a=5)
            yield a.send_pyobj(obj)
            recvd = yield f
            assert f.done()
            self.assertEqual(f.result(), obj)
            self.assertEqual(recvd, obj)
        self.loop.run_sync(test)

    def test_poll(self):
        """poll(): timeout=0 resolves at once, a short timeout resolves to 0,
        and a pending poll resolves to POLLIN once data arrives."""
        @gen.coroutine
        def test():
            a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
            f = b.poll(timeout=0)
            self.assertEqual(f.result(), 0)
            f = b.poll(timeout=1)
            assert not f.done()
            evt = yield f
            self.assertEqual(evt, 0)
            f = b.poll(timeout=1000)
            assert not f.done()
            yield a.send_multipart([b'hi', b'there'])
            evt = yield f
            self.assertEqual(evt, zmq.POLLIN)
            # Polling must not have consumed the message.
            recvd = yield b.recv_multipart()
            self.assertEqual(recvd, [b'hi', b'there'])
        self.loop.run_sync(test)
class TestFutureSocket(BaseZMQTestCase):
    """Tests for the Future-returning socket API (``future.Context``).

    Each test runs its coroutine body on a fresh Tornado IOLoop created in
    ``setUp``. ``tearDown`` closes it (with all registered fds) unless the
    test already did so and cleared ``self.loop``.
    """
    # Makes self.context / self.socket() produce future-returning sockets.
    Context = future.Context

    def setUp(self):
        """Create a fresh IOLoop and make it current before base setup."""
        self.loop = IOLoop()
        self.loop.make_current()
        super(TestFutureSocket, self).setUp()

    def tearDown(self):
        """Run base teardown, then close the loop if a test hasn't already."""
        super(TestFutureSocket, self).tearDown()
        # test_close_all_fds closes the loop itself and sets it to None.
        if self.loop:
            self.loop.close(all_fds=True)

    def test_socket_class(self):
        """Sockets from the future Context are ``future.Socket`` instances."""
        s = self.context.socket(zmq.PUSH)
        assert isinstance(s, future.Socket)
        s.close()

    def test_recv_multipart(self):
        """A pending recv_multipart future resolves once a message is sent."""
        @gen.coroutine
        def test():
            a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
            f = b.recv_multipart()
            assert not f.done()
            yield a.send(b'hi')
            recvd = yield f
            self.assertEqual(recvd, [b'hi'])
        self.loop.run_sync(test)

    def test_recv(self):
        """Queued recv() futures each get one frame, in request order."""
        @gen.coroutine
        def test():
            a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
            f1 = b.recv()
            f2 = b.recv()
            assert not f1.done()
            assert not f2.done()
            yield a.send_multipart([b'hi', b'there'])
            recvd = yield f2
            assert f1.done()
            self.assertEqual(f1.result(), b'hi')
            self.assertEqual(recvd, b'there')
        self.loop.run_sync(test)

    def test_recv_cancel(self):
        """A cancelled recv does not consume the message meant for the next one."""
        @gen.coroutine
        def test():
            a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
            f1 = b.recv()
            f2 = b.recv_multipart()
            assert f1.cancel()
            assert f1.done()
            assert not f2.done()
            yield a.send_multipart([b'hi', b'there'])
            recvd = yield f2
            assert f1.cancelled()
            assert f2.done()
            self.assertEqual(recvd, [b'hi', b'there'])
        self.loop.run_sync(test)

    @pytest.mark.skipif(not hasattr(zmq, 'RCVTIMEO'), reason="requires RCVTIMEO")
    def test_recv_timeout(self):
        """rcvtimeo is captured per request: f1 (100ms) times out with
        zmq.Again, while f2 (1000ms) still receives the later message."""
        @gen.coroutine
        def test():
            a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
            b.rcvtimeo = 100
            f1 = b.recv()
            b.rcvtimeo = 1000
            f2 = b.recv_multipart()
            with pytest.raises(zmq.Again):
                yield f1
            yield a.send_multipart([b'hi', b'there'])
            recvd = yield f2
            assert f2.done()
            self.assertEqual(recvd, [b'hi', b'there'])
        self.loop.run_sync(test)

    @pytest.mark.skipif(not hasattr(zmq, 'SNDTIMEO'), reason="requires SNDTIMEO")
    def test_send_timeout(self):
        """A send on an unconnected PUSH socket fails with zmq.Again after sndtimeo."""
        @gen.coroutine
        def test():
            s = self.socket(zmq.PUSH)
            s.sndtimeo = 100
            with pytest.raises(zmq.Again):
                yield s.send(b'not going anywhere')
        self.loop.run_sync(test)

    # NOTE(review): ``pytest.mark.now`` looks like an ad-hoc selection marker;
    # confirm it is registered, or drop it.
    @pytest.mark.now
    def test_send_noblock(self):
        """send(flags=NOBLOCK) with no peer fails immediately with zmq.Again."""
        @gen.coroutine
        def test():
            s = self.socket(zmq.PUSH)
            with pytest.raises(zmq.Again):
                yield s.send(b'not going anywhere', flags=zmq.NOBLOCK)
        self.loop.run_sync(test)

    @pytest.mark.now
    def test_send_multipart_noblock(self):
        """send_multipart(flags=NOBLOCK) with no peer fails with zmq.Again."""
        @gen.coroutine
        def test():
            s = self.socket(zmq.PUSH)
            with pytest.raises(zmq.Again):
                yield s.send_multipart([b'not going anywhere'], flags=zmq.NOBLOCK)
        self.loop.run_sync(test)

    def test_recv_string(self):
        """recv_string round-trips a non-ASCII unicode string."""
        @gen.coroutine
        def test():
            a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
            f = b.recv_string()
            assert not f.done()
            msg = u('πøøπ')
            yield a.send_string(msg)
            recvd = yield f
            assert f.done()
            self.assertEqual(f.result(), msg)
            self.assertEqual(recvd, msg)
        self.loop.run_sync(test)

    def test_recv_json(self):
        """recv_json round-trips a dict sent with send_json."""
        @gen.coroutine
        def test():
            a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
            f = b.recv_json()
            assert not f.done()
            obj = dict(a=5)
            yield a.send_json(obj)
            recvd = yield f
            assert f.done()
            self.assertEqual(f.result(), obj)
            self.assertEqual(recvd, obj)
        self.loop.run_sync(test)

    def test_recv_json_cancelled(self):
        """A cancelled recv_json must not swallow the next message."""
        @gen.coroutine
        def test():
            a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
            f = b.recv_json()
            assert not f.done()
            f.cancel()
            # cycle eventloop to allow cancel events to fire
            yield gen.sleep(0)
            obj = dict(a=5)
            yield a.send_json(obj)
            with pytest.raises(future.CancelledError):
                yield f
            assert f.done()
            # give it a chance to incorrectly consume the event
            events = yield b.poll(timeout=5)
            assert events
            yield gen.sleep(0)
            # make sure cancelled recv didn't eat up event
            recvd = yield gen.with_timeout(timedelta(seconds=5), b.recv_json())
            assert recvd == obj
        self.loop.run_sync(test)

    def test_recv_pyobj(self):
        """recv_pyobj round-trips a dict sent with send_pyobj."""
        @gen.coroutine
        def test():
            a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
            f = b.recv_pyobj()
            assert not f.done()
            obj = dict(a=5)
            yield a.send_pyobj(obj)
            recvd = yield f
            assert f.done()
            self.assertEqual(f.result(), obj)
            self.assertEqual(recvd, obj)
        self.loop.run_sync(test)

    def test_poll(self):
        """poll(): timeout=0 resolves at once, a short timeout resolves to 0,
        and a pending poll resolves to POLLIN once data arrives."""
        @gen.coroutine
        def test():
            a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
            f = b.poll(timeout=0)
            self.assertEqual(f.result(), 0)
            f = b.poll(timeout=1)
            assert not f.done()
            evt = yield f
            self.assertEqual(evt, 0)
            f = b.poll(timeout=1000)
            assert not f.done()
            yield a.send_multipart([b'hi', b'there'])
            evt = yield f
            self.assertEqual(evt, zmq.POLLIN)
            recvd = yield b.recv_multipart()
            self.assertEqual(recvd, [b'hi', b'there'])
        self.loop.run_sync(test)

    def test_close_all_fds(self):
        """Closing the loop with all_fds=True closes sockets created for it."""
        s = self.socket(zmq.PUB)
        self.loop.close(all_fds=True)
        self.loop = None # avoid second close later
        assert s.closed

    def test_poll_raw(self):
        """future.Poller also works on plain (non-zmq) file objects."""
        @gen.coroutine
        def test():
            p = future.Poller()
            # make a pipe
            r, w = os.pipe()
            r = os.fdopen(r, 'rb')
            w = os.fdopen(w, 'wb')
            # Close both ends even if an assertion below fails.
            try:
                # POLLOUT
                p.register(r, zmq.POLLIN)
                p.register(w, zmq.POLLOUT)
                evts = yield p.poll(timeout=1)
                evts = dict(evts)
                assert r.fileno() not in evts
                assert w.fileno() in evts
                assert evts[w.fileno()] == zmq.POLLOUT

                # POLLIN
                p.unregister(w)
                w.write(b'x')
                w.flush()
                evts = yield p.poll(timeout=1000)
                evts = dict(evts)
                assert r.fileno() in evts
                assert evts[r.fileno()] == zmq.POLLIN
                assert r.read(1) == b'x'
            finally:
                r.close()
                w.close()
        self.loop.run_sync(test)
class IOPubThread(object):
    """An object for sending IOPub messages in a background thread.

    Prevents a blocking main thread from delaying output.

    IOPubThread(pub_socket).background_socket is a Socket-API-providing object
    whose IO is always run in a thread.
    """

    def __init__(self, socket, pipe=False):
        """Create the IOPub thread (it is not started here).

        Parameters
        ----------
        socket : zmq Socket
            the socket on which messages will be sent.
        pipe : bool
            Whether this process should listen for IOPub messages piped
            from forked subprocesses.
        """
        self.socket = socket
        self.background_socket = BackgroundSocket(self)
        self._master_pid = os.getpid()
        self._pipe_flag = pipe
        # Do not make this loop current in the constructing thread; it is
        # made current inside the IO thread (see _thread_main).
        self.io_loop = IOLoop(make_current=False)
        if pipe:
            self._setup_pipe_in()
        self.thread = threading.Thread(target=self._thread_main)
        self.thread.daemon = True

    def _thread_main(self):
        """The inner loop that's actually run in a thread"""
        self.io_loop.make_current()
        self.io_loop.start()
        # all_fds=True also closes sockets registered on the loop (the
        # subprocess pipe listener), which would otherwise leak.
        self.io_loop.close(all_fds=True)

    def _setup_pipe_in(self):
        """setup listening pipe for subprocesses"""
        ctx = self.socket.context

        # use UUID to authenticate pipe messages
        self._pipe_uuid = uuid.uuid4().bytes

        pipe_in = ctx.socket(zmq.PULL)
        pipe_in.linger = 0
        try:
            self._pipe_port = pipe_in.bind_to_random_port("tcp://127.0.0.1")
        except zmq.ZMQError as e:
            warnings.warn("Couldn't bind IOPub Pipe to 127.0.0.1: %s" % e +
                "\nsubprocess output will be unavailable."
            )
            # Fall back to no-pipe mode; children will not be able to forward.
            self._pipe_flag = False
            pipe_in.close()
            return
        self._pipe_in = ZMQStream(pipe_in, self.io_loop)
        self._pipe_in.on_recv(self._handle_pipe_msg)

    def _handle_pipe_msg(self, msg):
        """handle a pipe message from a subprocess

        The first frame must equal this process's pipe UUID; the remaining
        frames are re-sent via send_multipart. Anything else is reported on
        sys.__stderr__ and dropped.
        """
        if not self._pipe_flag or not self._is_master_process():
            return
        if msg[0] != self._pipe_uuid:
            # Format explicitly: print() does not do %-interpolation itself.
            print("Bad pipe message: %s" % (msg,), file=sys.__stderr__)
            return
        self.send_multipart(msg[1:])

    def _setup_pipe_out(self):
        """Create a fresh context and PUSH socket connected to the master's pipe."""
        # must be new context after fork
        ctx = zmq.Context()
        pipe_out = ctx.socket(zmq.PUSH)
        pipe_out.linger = 3000 # 3s timeout for pipe_out sends before discarding the message
        pipe_out.connect("tcp://127.0.0.1:%i" % self._pipe_port)
        return ctx, pipe_out

    def _is_master_process(self):
        """True when running in the process that created this object."""
        return os.getpid() == self._master_pid

    def _check_mp_mode(self):
        """check for forks, and switch to zmq pipeline if necessary"""
        if not self._pipe_flag or self._is_master_process():
            return MASTER
        else:
            return CHILD

    def start(self):
        """Start the IOPub thread"""
        self.thread.start()
        # make sure we don't prevent process exit
        # I'm not sure why setting daemon=True above isn't enough, but it doesn't appear to be.
        atexit.register(self.stop)

    def stop(self):
        """Stop the IOPub thread"""
        if not self.thread.is_alive():
            return
        # add_callback is safe to call from another thread.
        self.io_loop.add_callback(self.io_loop.stop)
        self.thread.join()

    def close(self):
        """Close the outgoing socket and mark this object closed."""
        self.socket.close()
        self.socket = None

    @property
    def closed(self):
        return self.socket is None

    def send_multipart(self, *args, **kwargs):
        """send_multipart schedules actual zmq send in my thread.

        If my thread isn't running (e.g. forked process), send immediately.
        """
        if self.thread.is_alive():
            self.io_loop.add_callback(
                lambda: self._really_send(*args, **kwargs))
        else:
            self._really_send(*args, **kwargs)

    def _really_send(self, msg, *args, **kwargs):
        """The callback that actually sends messages"""
        mp_mode = self._check_mp_mode()

        if mp_mode != CHILD:
            # we are master, do a regular send
            self.socket.send_multipart(msg, *args, **kwargs)
        else:
            # we are a child, pipe to master
            # new context/socket for every pipe-out
            # since forks don't teardown politely, use ctx.term to ensure send has completed
            ctx, pipe_out = self._setup_pipe_out()
            pipe_out.send_multipart([self._pipe_uuid] + msg, *args, **kwargs)
            pipe_out.close()
            ctx.term()
class TestFutureSocket(BaseZMQTestCase):
    """Tests for the Future-returning socket API (``future.Context``).

    Each test runs its coroutine body on a fresh Tornado IOLoop created in
    ``setUp``. ``tearDown`` closes it (with all registered fds) unless the
    test already did so and cleared ``self.loop``.
    """
    # Makes self.context / self.socket() produce future-returning sockets.
    Context = future.Context

    def setUp(self):
        """Create a fresh IOLoop and make it current before base setup."""
        self.loop = IOLoop()
        self.loop.make_current()
        super(TestFutureSocket, self).setUp()

    def tearDown(self):
        """Run base teardown, then close the loop if a test hasn't already."""
        super(TestFutureSocket, self).tearDown()
        # test_close_all_fds closes the loop itself and sets it to None.
        if self.loop:
            self.loop.close(all_fds=True)

    def test_socket_class(self):
        """Sockets from the future Context are ``future.Socket`` instances."""
        s = self.context.socket(zmq.PUSH)
        assert isinstance(s, future.Socket)
        s.close()

    def test_recv_multipart(self):
        """A pending recv_multipart future resolves once a message is sent."""
        @gen.coroutine
        def test():
            a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
            f = b.recv_multipart()
            assert not f.done()
            yield a.send(b'hi')
            recvd = yield f
            self.assertEqual(recvd, [b'hi'])
        self.loop.run_sync(test)

    def test_recv(self):
        """Queued recv() futures each get one frame, in request order."""
        @gen.coroutine
        def test():
            a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
            f1 = b.recv()
            f2 = b.recv()
            assert not f1.done()
            assert not f2.done()
            yield a.send_multipart([b'hi', b'there'])
            recvd = yield f2
            assert f1.done()
            self.assertEqual(f1.result(), b'hi')
            self.assertEqual(recvd, b'there')
        self.loop.run_sync(test)

    def test_recv_cancel(self):
        """A cancelled recv does not consume the message meant for the next one."""
        @gen.coroutine
        def test():
            a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
            f1 = b.recv()
            f2 = b.recv_multipart()
            assert f1.cancel()
            assert f1.done()
            assert not f2.done()
            yield a.send_multipart([b'hi', b'there'])
            recvd = yield f2
            assert f1.cancelled()
            assert f2.done()
            self.assertEqual(recvd, [b'hi', b'there'])
        self.loop.run_sync(test)

    @pytest.mark.skipif(not hasattr(zmq, 'RCVTIMEO'), reason="requires RCVTIMEO")
    def test_recv_timeout(self):
        """rcvtimeo is captured per request: f1 (100ms) times out with
        zmq.Again, while f2 (1000ms) still receives the later message."""
        @gen.coroutine
        def test():
            a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
            b.rcvtimeo = 100
            f1 = b.recv()
            b.rcvtimeo = 1000
            f2 = b.recv_multipart()
            with pytest.raises(zmq.Again):
                yield f1
            yield a.send_multipart([b'hi', b'there'])
            recvd = yield f2
            assert f2.done()
            self.assertEqual(recvd, [b'hi', b'there'])
        self.loop.run_sync(test)

    @pytest.mark.skipif(not hasattr(zmq, 'SNDTIMEO'), reason="requires SNDTIMEO")
    def test_send_timeout(self):
        """A send on an unconnected PUSH socket fails with zmq.Again after sndtimeo."""
        @gen.coroutine
        def test():
            s = self.socket(zmq.PUSH)
            s.sndtimeo = 100
            with pytest.raises(zmq.Again):
                yield s.send(b'not going anywhere')
        self.loop.run_sync(test)

    # NOTE(review): ``pytest.mark.now`` looks like an ad-hoc selection marker;
    # confirm it is registered, or drop it.
    @pytest.mark.now
    def test_send_noblock(self):
        """send(flags=NOBLOCK) with no peer fails immediately with zmq.Again."""
        @gen.coroutine
        def test():
            s = self.socket(zmq.PUSH)
            with pytest.raises(zmq.Again):
                yield s.send(b'not going anywhere', flags=zmq.NOBLOCK)
        self.loop.run_sync(test)

    @pytest.mark.now
    def test_send_multipart_noblock(self):
        """send_multipart(flags=NOBLOCK) with no peer fails with zmq.Again."""
        @gen.coroutine
        def test():
            s = self.socket(zmq.PUSH)
            with pytest.raises(zmq.Again):
                yield s.send_multipart([b'not going anywhere'], flags=zmq.NOBLOCK)
        self.loop.run_sync(test)

    def test_recv_string(self):
        """recv_string round-trips a non-ASCII unicode string."""
        @gen.coroutine
        def test():
            a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
            f = b.recv_string()
            assert not f.done()
            msg = u('πøøπ')
            yield a.send_string(msg)
            recvd = yield f
            assert f.done()
            self.assertEqual(f.result(), msg)
            self.assertEqual(recvd, msg)
        self.loop.run_sync(test)

    def test_recv_json(self):
        """recv_json round-trips a dict sent with send_json."""
        @gen.coroutine
        def test():
            a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
            f = b.recv_json()
            assert not f.done()
            obj = dict(a=5)
            yield a.send_json(obj)
            recvd = yield f
            assert f.done()
            self.assertEqual(f.result(), obj)
            self.assertEqual(recvd, obj)
        self.loop.run_sync(test)

    def test_recv_json_cancelled(self):
        """A cancelled recv_json must not swallow the next message."""
        @gen.coroutine
        def test():
            a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
            f = b.recv_json()
            assert not f.done()
            f.cancel()
            # cycle eventloop to allow cancel events to fire
            yield gen.sleep(0)
            obj = dict(a=5)
            yield a.send_json(obj)
            with pytest.raises(future.CancelledError):
                yield f
            assert f.done()
            # give it a chance to incorrectly consume the event
            events = yield b.poll(timeout=5)
            assert events
            yield gen.sleep(0)
            # make sure cancelled recv didn't eat up event
            recvd = yield gen.with_timeout(timedelta(seconds=5), b.recv_json())
            assert recvd == obj
        self.loop.run_sync(test)

    def test_recv_pyobj(self):
        """recv_pyobj round-trips a dict sent with send_pyobj."""
        @gen.coroutine
        def test():
            a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
            f = b.recv_pyobj()
            assert not f.done()
            obj = dict(a=5)
            yield a.send_pyobj(obj)
            recvd = yield f
            assert f.done()
            self.assertEqual(f.result(), obj)
            self.assertEqual(recvd, obj)
        self.loop.run_sync(test)

    def test_poll(self):
        """poll(): timeout=0 resolves at once, a short timeout resolves to 0,
        and a pending poll resolves to POLLIN once data arrives."""
        @gen.coroutine
        def test():
            a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
            f = b.poll(timeout=0)
            self.assertEqual(f.result(), 0)
            f = b.poll(timeout=1)
            assert not f.done()
            evt = yield f
            self.assertEqual(evt, 0)
            f = b.poll(timeout=1000)
            assert not f.done()
            yield a.send_multipart([b'hi', b'there'])
            evt = yield f
            self.assertEqual(evt, zmq.POLLIN)
            recvd = yield b.recv_multipart()
            self.assertEqual(recvd, [b'hi', b'there'])
        self.loop.run_sync(test)

    def test_close_all_fds(self):
        """Closing the loop with all_fds=True closes sockets created for it."""
        s = self.socket(zmq.PUB)
        self.loop.close(all_fds=True)
        self.loop = None # avoid second close later
        assert s.closed

    def test_poll_raw(self):
        """future.Poller also works on plain (non-zmq) file objects."""
        @gen.coroutine
        def test():
            p = future.Poller()
            # make a pipe
            r, w = os.pipe()
            r = os.fdopen(r, 'rb')
            w = os.fdopen(w, 'wb')
            # Close both ends even if an assertion below fails.
            try:
                # POLLOUT
                p.register(r, zmq.POLLIN)
                p.register(w, zmq.POLLOUT)
                evts = yield p.poll(timeout=1)
                evts = dict(evts)
                assert r.fileno() not in evts
                assert w.fileno() in evts
                assert evts[w.fileno()] == zmq.POLLOUT

                # POLLIN
                p.unregister(w)
                w.write(b'x')
                w.flush()
                evts = yield p.poll(timeout=1000)
                evts = dict(evts)
                assert r.fileno() in evts
                assert evts[r.fileno()] == zmq.POLLIN
                assert r.read(1) == b'x'
            finally:
                r.close()
                w.close()
        self.loop.run_sync(test)
class IOPubThread(object):
    """An object for sending IOPub messages in a background thread

    Prevents a blocking main thread from delaying output from threads.

    IOPubThread(pub_socket).background_socket is a Socket-API-providing object
    whose IO is always run in a thread.
    """

    def __init__(self, socket, pipe=False):
        """Create IOPub thread

        Parameters
        ----------

        socket: zmq.PUB Socket
            the socket on which messages will be sent.
        pipe: bool
            Whether this process should listen for IOPub messages
            piped from subprocesses.
        """
        self.socket = socket
        self.background_socket = BackgroundSocket(self)
        self._master_pid = os.getpid()
        self._pipe_flag = pipe
        # The loop becomes current only inside the IO thread (_thread_main).
        self.io_loop = IOLoop(make_current=False)
        if pipe:
            self._setup_pipe_in()
        # Per-thread PUSH sockets used to wake the IO thread (see _event_pipe).
        self._local = threading.local()
        # Callables waiting to run in the IO thread (see schedule).
        self._events = deque()
        self._setup_event_pipe()
        self.thread = threading.Thread(target=self._thread_main)
        self.thread.daemon = True

    def _thread_main(self):
        """The inner loop that's actually run in a thread"""
        self.io_loop.make_current()
        self.io_loop.start()
        self.io_loop.close(all_fds=True)

    def _setup_event_pipe(self):
        """Create the PULL socket listening for events that should fire in this thread."""
        ctx = self.socket.context
        pipe_in = ctx.socket(zmq.PULL)
        pipe_in.linger = 0

        # Random inproc endpoint name, unique per IOPubThread instance.
        _uuid = b2a_hex(os.urandom(16)).decode('ascii')
        iface = self._event_interface = 'inproc://%s' % _uuid
        pipe_in.bind(iface)
        self._event_puller = ZMQStream(pipe_in, self.io_loop)
        self._event_puller.on_recv(self._handle_event)

    @property
    def _event_pipe(self):
        """thread-local event pipe for signaling events that should be processed in the thread"""
        try:
            event_pipe = self._local.event_pipe
        except AttributeError:
            # new thread, new event pipe
            ctx = self.socket.context
            event_pipe = ctx.socket(zmq.PUSH)
            event_pipe.linger = 0
            event_pipe.connect(self._event_interface)
            self._local.event_pipe = event_pipe
        return event_pipe

    def _handle_event(self, msg):
        """Handle an event on the event pipe

        Content of the message is ignored.

        Whenever *an* event arrives on the event stream,
        *all* waiting events are processed in order.
        """
        # freeze event count so new writes don't extend the queue
        # while we are processing
        n_events = len(self._events)
        for _ in range(n_events):
            event_f = self._events.popleft()
            event_f()

    def _setup_pipe_in(self):
        """setup listening pipe for IOPub from forked subprocesses"""
        ctx = self.socket.context

        # use UUID to authenticate pipe messages
        self._pipe_uuid = os.urandom(16)

        pipe_in = ctx.socket(zmq.PULL)
        pipe_in.linger = 0

        try:
            self._pipe_port = pipe_in.bind_to_random_port("tcp://127.0.0.1")
        except zmq.ZMQError as e:
            warnings.warn("Couldn't bind IOPub Pipe to 127.0.0.1: %s" % e +
                "\nsubprocess output will be unavailable."
            )
            # Fall back to no-pipe mode; children will not be able to forward.
            self._pipe_flag = False
            pipe_in.close()
            return
        self._pipe_in = ZMQStream(pipe_in, self.io_loop)
        self._pipe_in.on_recv(self._handle_pipe_msg)

    def _handle_pipe_msg(self, msg):
        """handle a pipe message from a subprocess

        The first frame must equal this process's pipe key; the remaining
        frames are re-sent via send_multipart. Anything else is reported on
        sys.__stderr__ and dropped.
        """
        if not self._pipe_flag or not self._is_master_process():
            return
        if msg[0] != self._pipe_uuid:
            # Format explicitly: print() does not do %-interpolation itself.
            print("Bad pipe message: %s" % (msg,), file=sys.__stderr__)
            return
        self.send_multipart(msg[1:])

    def _setup_pipe_out(self):
        """Create a fresh context and PUSH socket connected to the master's pipe."""
        # must be new context after fork
        ctx = zmq.Context()
        pipe_out = ctx.socket(zmq.PUSH)
        pipe_out.linger = 3000 # 3s timeout for pipe_out sends before discarding the message
        pipe_out.connect("tcp://127.0.0.1:%i" % self._pipe_port)
        return ctx, pipe_out

    def _is_master_process(self):
        """True when running in the process that created this object."""
        return os.getpid() == self._master_pid

    def _check_mp_mode(self):
        """check for forks, and switch to zmq pipeline if necessary"""
        if not self._pipe_flag or self._is_master_process():
            return MASTER
        else:
            return CHILD

    def start(self):
        """Start the IOPub thread"""
        self.thread.start()
        # make sure we don't prevent process exit
        # I'm not sure why setting daemon=True above isn't enough, but it doesn't appear to be.
        atexit.register(self.stop)

    def stop(self):
        """Stop the IOPub thread"""
        if not self.thread.is_alive():
            return
        self.io_loop.add_callback(self.io_loop.stop)
        self.thread.join()
        # Only the calling thread's event pipe is closed here; pipes created
        # by other threads are not tracked.
        if hasattr(self._local, 'event_pipe'):
            self._local.event_pipe.close()

    def close(self):
        """Close the outgoing socket and mark this object closed."""
        self.socket.close()
        self.socket = None

    @property
    def closed(self):
        return self.socket is None

    def schedule(self, f):
        """Schedule a function to be called in our IO thread.

        If the thread is not running, call immediately.
        """
        if self.thread.is_alive():
            self._events.append(f)
            # wake event thread (message content is ignored)
            self._event_pipe.send(b'')
        else:
            f()

    def send_multipart(self, *args, **kwargs):
        """send_multipart schedules actual zmq send in my thread.

        If my thread isn't running (e.g. forked process), send immediately.
        """
        self.schedule(lambda: self._really_send(*args, **kwargs))

    def _really_send(self, msg, *args, **kwargs):
        """The callback that actually sends messages"""
        mp_mode = self._check_mp_mode()

        if mp_mode != CHILD:
            # we are master, do a regular send
            self.socket.send_multipart(msg, *args, **kwargs)
        else:
            # we are a child, pipe to master
            # new context/socket for every pipe-out
            # since forks don't teardown politely, use ctx.term to ensure send has completed
            ctx, pipe_out = self._setup_pipe_out()
            pipe_out.send_multipart([self._pipe_uuid] + msg, *args, **kwargs)
            pipe_out.close()
            ctx.term()
class ZMQ(Client):
    """ZeroMQ transport for the client API.

    A plain instance talks over a REQ socket. With ``nowait=True`` it uses a
    DEALER socket so requests can be sent without reading a reply. Instances
    are context managers; leaving the ``with`` block closes any stream/loop
    created for fireball sends, the socket, and the context.
    """

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # do our best to clean up potentially leaky FDs
        if hasattr(self, 'stream') and not self.stream.closed():
            self.stream.close()

        if hasattr(self, 'loop'):
            try:
                self.loop.close()
            except KeyError:
                # NOTE(review): presumably raised when the loop was already
                # closed by _send_fireball -- confirm; tolerated as best-effort.
                pass

        self.socket.close()
        self.context.term()

    def __init__(self, remote, token, **kwargs):
        """Create the transport.

        :param remote: endpoint the socket connects to on each request
        :param token: auth token attached to every message
        :keyword nowait: use a DEALER socket and don't wait for replies
            (default False)
        :keyword autoclose: when not waiting for a reply, close the socket
            right after sending (default True)
        """
        super(ZMQ, self).__init__(remote, token)

        self.nowait = kwargs.get('nowait', False)
        self.autoclose = kwargs.get('autoclose', True)

        self.context = zmq.Context()
        # Create a single socket of the right type and configure it, so the
        # nowait DEALER gets the same timeouts/linger as the REQ socket and no
        # unused socket is left open on the context.
        socket_type = zmq.DEALER if self.nowait else zmq.REQ
        self.socket = self.context.socket(socket_type)
        self.socket.RCVTIMEO = RCVTIMEO
        self.socket.SNDTIMEO = SNDTIMEO
        self.socket.setsockopt(zmq.LINGER, LINGER)

    def _handle_message_fireball(self, m):
        """Collect one fireball reply; stop the loop once all have arrived."""
        logger.debug('message received')

        # JSON payload is read from frame index 2 (frame layout is defined by
        # Msg on the server side -- confirm if it changes).
        m = json.loads(m[2].decode('utf-8'))
        self.response.append(m)

        self.num_responses -= 1
        logger.debug('num responses remaining: %i', self.num_responses)

        if self.num_responses == 0:
            logger.debug('finishing up...')
            self.loop.stop()

    def _fireball_timeout(self):
        """Give up waiting for the remaining fireball replies."""
        logger.info('fireball timeout')
        self.loop.stop()

    def _send_fireball(self, mtype, data, f_size):
        """Send a JSON list in batches of ``f_size`` and collect the replies.

        :param mtype: message type used for every batch
        :param data: JSON text (str/bytes) holding one object or a list
        :param f_size: maximum number of items per batch
        :return: list of decoded replies (possibly partial if the timeout
            fires), or ``[]`` when there is nothing to send
        """
        if len(data) < 3:
            logger.error('no data to send')
            return []

        if PYVERSION == 3 and isinstance(data, bytes):
            data = data.decode('utf-8')

        data = json.loads(data)

        if not isinstance(data, list):
            data = [data]

        # An empty list would expect zero replies, and the loop would then
        # only stop on the timeout.
        if not data:
            logger.error('no data to send')
            return []

        # NOTE(review): IOLoop().instance() builds a throwaway loop and then
        # returns the current/global one. The stream below registers on the
        # current loop implicitly, so this is kept as-is -- confirm before
        # simplifying.
        self.loop = IOLoop().instance()

        # Replace the current socket with a DEALER, which (unlike REQ) allows
        # several sends before any reply is read.
        self.socket.close()
        self.socket = self.context.socket(zmq.DEALER)
        self.socket.connect(self.remote)
        self.stream = ZMQStream(self.socket)
        self.stream.on_recv(self._handle_message_fireball)

        # NOTE(review): call_later() takes seconds while zmq's SNDTIMEO option
        # is in milliseconds -- confirm the unit of the SNDTIMEO constant.
        self.stream.io_loop.call_later(SNDTIMEO, self._fireball_timeout)

        self.response = []

        # one reply is expected per batch: ceil(len(data) / f_size)
        self.num_responses = (len(data) + f_size - 1) // f_size
        logger.debug('responses expected: %i', self.num_responses)

        for start in range(0, len(data), f_size):
            batch = data[start:start + f_size]
            Msg(mtype=mtype, token=self.token, data=batch).send(self.socket)

        logger.debug("starting loop to receive")
        self.loop.start()

        # clean up FDs
        self.loop.close()
        self.stream.close()
        self.socket.close()
        return self.response

    def _recv(self, decode=True, close=True):
        """Receive one reply and unwrap its ``data`` field.

        :param decode: when False, return the raw payload undecoded
        :param close: close the socket after receiving
        :raises AuthError: message == 'unauthorized'
        :raises CIFBusy: message == 'busy'
        :raises InvalidSearch: message == 'invalid search'
        :raises RuntimeError: non-success status or missing data
        """
        mtype, data = Msg().recv(self.socket)

        if close:
            self.socket.close()

        if not decode:
            return data

        data = json.loads(data)

        if data.get('message') == 'unauthorized':
            raise AuthError()

        if data.get('message') == 'busy':
            raise CIFBusy()

        if data.get('message') == 'invalid search':
            raise InvalidSearch()

        if data.get('status') != 'success':
            raise RuntimeError(data.get('message'))

        if data.get('data') is None:
            raise RuntimeError('invalid response')

        if isinstance(data.get('data'), bool):
            return data['data']

        # is this a straight up elasticsearch string?
        if data['data'] == '{}':
            return []

        if isinstance(data['data'], basestring) and data['data'].startswith('{"hits":{"hits":[{"_source":'):
            data['data'] = json.loads(data['data'])
            data['data'] = [r['_source'] for r in data['data']['hits']['hits']]

        # payload may be zlib-compressed; anything that isn't is left as-is
        try:
            data['data'] = zlib.decompress(data['data'])
        except (zlib.error, TypeError):
            pass

        return data.get('data')

    def _send(self, mtype, data='[]', nowait=False, decode=True):
        """Connect, send one message and (unless not waiting) return the reply."""
        self.socket.connect(self.remote)

        if isinstance(data, str):
            data = data.encode('utf-8')

        Msg(mtype=mtype, token=self.token, data=data).send(self.socket)

        if self.nowait or nowait:
            if self.autoclose:
                self.socket.close()
            return

        rv = self._recv(decode=decode)
        return rv

    def ping(self):
        try:
            return self._send(Msg.PING)
        except zmq.error.Again:
            raise TimeoutError

    def ping_write(self):
        try:
            return self._send(Msg.PING_WRITE)
        except zmq.error.Again:
            raise TimeoutError

    def indicators_search(self, filters, decode=True):
        return self._send(Msg.INDICATORS_SEARCH, json.dumps(filters), decode=decode)

    def graph_search(self, filters, decode=True):
        return self._send(Msg.GRAPH_SEARCH, json.dumps(filters), decode=decode)

    def stats_search(self, filters, decode=True):
        return self._send(Msg.STATS_SEARCH, json.dumps(filters), decode=decode)

    def indicators_create(self, data, nowait=False, fireball=False, f_size=FIREBALL_SIZE):
        """Create indicators; with ``fireball=True`` send them in batches."""
        if isinstance(data, dict):
            data = self._kv_to_indicator(data)

        if isinstance(data, Indicator):
            data = str(data)

        if fireball:
            return self._send_fireball(Msg.INDICATORS_CREATE, data, f_size)

        return self._send(Msg.INDICATORS_CREATE, data, nowait=nowait)

    def indicators_delete(self, data):
        if isinstance(data, dict):
            data = self._kv_to_indicator(data)

        if isinstance(data, Indicator):
            data = str(data)

        return self._send(Msg.INDICATORS_DELETE, data)

    def tokens_search(self, filters=None):
        # None instead of a shared mutable {} default; same request either way
        if filters is None:
            filters = {}
        return self._send(Msg.TOKENS_SEARCH, json.dumps(filters))

    def tokens_create(self, data):
        return self._send(Msg.TOKENS_CREATE, data)

    def tokens_delete(self, data):
        return self._send(Msg.TOKENS_DELETE, data)

    def tokens_edit(self, data):
        return self._send(Msg.TOKENS_EDIT, data)
class IOPubThread(object):
    """An object for sending IOPub messages in a background thread

    Prevents a blocking main thread from delaying output from threads.

    IOPubThread(pub_socket).background_socket is a Socket-API-providing object
    whose IO is always run in a thread.
    """

    def __init__(self, socket, pipe=False):
        """Create IOPub thread

        Parameters
        ----------
        socket: zmq.PUB Socket
            the socket on which messages will be sent.
        pipe: bool
            Whether this process should listen for IOPub messages
            piped from subprocesses.
        """
        self.socket = socket
        self.background_socket = BackgroundSocket(self)
        self._master_pid = os.getpid()
        self._pipe_flag = pipe
        # not made current here; _thread_main makes it current inside the IO thread
        self.io_loop = IOLoop(make_current=False)
        if pipe:
            self._setup_pipe_in()
        # holds one PUSH "wake-up" socket per calling thread (see _event_pipe)
        self._local = threading.local()
        # callables queued by schedule() and drained in order by _handle_event
        self._events = deque()
        self._setup_event_pipe()
        self.thread = threading.Thread(target=self._thread_main)
        self.thread.daemon = True

    def _thread_main(self):
        """The inner loop that's actually run in a thread"""
        self.io_loop.make_current()
        self.io_loop.start()
        self.io_loop.close(all_fds=True)

    def _setup_event_pipe(self):
        """Create the PULL socket listening for events that should fire in this thread."""
        ctx = self.socket.context
        pipe_in = ctx.socket(zmq.PULL)
        pipe_in.linger = 0

        # random inproc endpoint so several IOPubThreads can share one context
        _uuid = b2a_hex(os.urandom(16)).decode('ascii')
        iface = self._event_interface = 'inproc://%s' % _uuid
        pipe_in.bind(iface)
        self._event_puller = ZMQStream(pipe_in, self.io_loop)
        self._event_puller.on_recv(self._handle_event)

    @property
    def _event_pipe(self):
        """thread-local event pipe for signaling events that should be processed in the thread"""
        try:
            event_pipe = self._local.event_pipe
        except AttributeError:
            # new thread, new event pipe
            ctx = self.socket.context
            event_pipe = ctx.socket(zmq.PUSH)
            event_pipe.linger = 0
            event_pipe.connect(self._event_interface)
            self._local.event_pipe = event_pipe
        return event_pipe

    def _handle_event(self, msg):
        """Handle an event on the event pipe

        Content of the message is ignored.

        Whenever *an* event arrives on the event stream,
        *all* waiting events are processed in order.
        """
        # freeze event count so new writes don't extend the queue
        # while we are processing
        n_events = len(self._events)
        for _ in range(n_events):
            event_f = self._events.popleft()
            event_f()

    def _setup_pipe_in(self):
        """setup listening pipe for IOPub from forked subprocesses"""
        ctx = self.socket.context

        # use UUID to authenticate pipe messages
        self._pipe_uuid = os.urandom(16)

        pipe_in = ctx.socket(zmq.PULL)
        pipe_in.linger = 0

        try:
            self._pipe_port = pipe_in.bind_to_random_port("tcp://127.0.0.1")
        except zmq.ZMQError as e:
            warnings.warn("Couldn't bind IOPub Pipe to 127.0.0.1: %s" % e +
                "\nsubprocess output will be unavailable."
            )
            # degrade gracefully: children will not pipe, _check_mp_mode reports MASTER
            self._pipe_flag = False
            pipe_in.close()
            return
        self._pipe_in = ZMQStream(pipe_in, self.io_loop)
        self._pipe_in.on_recv(self._handle_pipe_msg)

    def _handle_pipe_msg(self, msg):
        """handle a pipe message from a subprocess

        ``msg[0]`` must equal the per-process UUID; the remaining frames are
        re-sent on the IOPub socket. Mismatched messages are reported on the
        real stderr and dropped.
        """
        if not self._pipe_flag or not self._is_master_process():
            return
        if msg[0] != self._pipe_uuid:
            # format explicitly: passing msg as a second print() argument
            # would emit a literal "%s"
            print("Bad pipe message: %s" % (msg,), file=sys.__stderr__)
            return
        self.send_multipart(msg[1:])

    def _setup_pipe_out(self):
        """Create a fresh context and PUSH socket connected to the master's pipe.

        Returns
        -------
        (ctx, pipe_out) : the new zmq.Context and its connected PUSH socket.
        """
        # must be new context after fork
        ctx = zmq.Context()
        pipe_out = ctx.socket(zmq.PUSH)
        pipe_out.linger = 3000 # 3s timeout for pipe_out sends before discarding the message
        pipe_out.connect("tcp://127.0.0.1:%i" % self._pipe_port)
        return ctx, pipe_out

    def _is_master_process(self):
        return os.getpid() == self._master_pid

    def _check_mp_mode(self):
        """check for forks, and switch to zmq pipeline if necessary"""
        if not self._pipe_flag or self._is_master_process():
            return MASTER
        else:
            return CHILD

    def start(self):
        """Start the IOPub thread"""
        self.thread.start()
        # make sure we don't prevent process exit
        # I'm not sure why setting daemon=True above isn't enough, but it doesn't appear to be.
        atexit.register(self.stop)

    def stop(self):
        """Stop the IOPub thread (no-op if it is not running)."""
        if not self.thread.is_alive():
            return
        self.io_loop.add_callback(self.io_loop.stop)
        self.thread.join()
        # only the calling thread's event pipe is reachable through threading.local
        if hasattr(self._local, 'event_pipe'):
            self._local.event_pipe.close()

    def close(self):
        self.socket.close()
        self.socket = None

    @property
    def closed(self):
        return self.socket is None

    def schedule(self, f):
        """Schedule a function to be called in our IO thread.

        If the thread is not running, call immediately.
        """
        if self.thread.is_alive():
            self._events.append(f)
            # wake event thread (message content is ignored)
            self._event_pipe.send(b'')
        else:
            f()

    def send_multipart(self, *args, **kwargs):
        """send_multipart schedules actual zmq send in my thread.

        If my thread isn't running (e.g. forked process), send immediately.
        """
        self.schedule(lambda : self._really_send(*args, **kwargs))

    def _really_send(self, msg, *args, **kwargs):
        """The callback that actually sends messages"""
        mp_mode = self._check_mp_mode()

        if mp_mode != CHILD:
            # we are master, do a regular send
            self.socket.send_multipart(msg, *args, **kwargs)
        else:
            # we are a child, pipe to master
            # new context/socket for every pipe-out
            # since forks don't teardown politely, use ctx.term to ensure send has completed
            ctx, pipe_out = self._setup_pipe_out()
            try:
                pipe_out.send_multipart([self._pipe_uuid] + msg, *args, **kwargs)
            finally:
                # release the per-send socket/context even if the send raised
                pipe_out.close()
                ctx.term()
class IOPubThread(object):
    """An object for sending IOPub messages in a background thread

    prevents a blocking main thread

    IOPubThread(pub_socket).background_socket is a Socket-API-providing object
    whose IO is always run in a thread.
    """

    def __init__(self, socket, pipe=False):
        """Create the IOPub thread (it is not started here).

        Parameters
        ----------
        socket : zmq socket
            the socket on which messages will be sent.
        pipe : bool
            whether to listen for messages piped from forked subprocesses.
        """
        self.socket = socket
        self.background_socket = BackgroundSocket(self)
        self._master_pid = os.getpid()
        # pid that owns self._pipe_out; starts as the master pid so that a
        # forked child sees a mismatch and builds its own pipe (_have_pipe_out)
        self._pipe_pid = os.getpid()
        self._pipe_flag = pipe
        self.io_loop = IOLoop()
        if pipe:
            self._setup_pipe_in()
        self.thread = threading.Thread(target=self._thread_main)
        self.thread.daemon = True

    def _thread_main(self):
        """The inner loop that's actually run in a thread"""
        self.io_loop.start()
        # all_fds=True also closes the sockets registered on this loop
        # (e.g. the ZMQStream-wrapped subprocess pipe) instead of leaving them open
        self.io_loop.close(all_fds=True)

    def _setup_pipe_in(self):
        """setup listening pipe for subprocesses"""
        ctx = self.socket.context

        # use UUID to authenticate pipe messages
        self._pipe_uuid = uuid.uuid4().bytes

        pipe_in = ctx.socket(zmq.PULL)
        pipe_in.linger = 0
        try:
            self._pipe_port = pipe_in.bind_to_random_port("tcp://127.0.0.1")
        except zmq.ZMQError as e:
            warn("Couldn't bind IOPub Pipe to 127.0.0.1: %s" % e +
                "\nsubprocess output will be unavailable."
            )
            # degrade gracefully: _check_mp_mode will report MASTER everywhere
            self._pipe_flag = False
            pipe_in.close()
            return
        self._pipe_in = ZMQStream(pipe_in, self.io_loop)
        self._pipe_in.on_recv(self._handle_pipe_msg)

    def _handle_pipe_msg(self, msg):
        """handle a pipe message from a subprocess

        ``msg[0]`` must equal the pipe UUID; the remaining frames are re-sent
        on the IOPub socket. Mismatched messages are reported and dropped.
        """
        if not self._pipe_flag or not self._is_master_process():
            return
        if msg[0] != self._pipe_uuid:
            # format explicitly: passing msg as a second print() argument
            # would emit a literal "%s"
            print("Bad pipe message: %s" % (msg,), file=sys.__stderr__)
            return
        self.send_multipart(msg[1:])

    def _setup_pipe_out(self):
        """Create this process's PUSH socket to the master's pipe."""
        # must be new context after fork
        ctx = zmq.Context()
        self._pipe_pid = os.getpid()
        self._pipe_out = ctx.socket(zmq.PUSH)
        self._pipe_out.connect("tcp://127.0.0.1:%i" % self._pipe_port)

    def _is_master_process(self):
        return os.getpid() == self._master_pid

    def _have_pipe_out(self):
        return os.getpid() == self._pipe_pid

    def _check_mp_mode(self):
        """check for forks, and switch to zmq pipeline if necessary"""
        if not self._pipe_flag or self._is_master_process():
            return MASTER
        else:
            if not self._have_pipe_out():
                # setup a new out pipe
                self._setup_pipe_out()
            return CHILD

    def start(self):
        """Start the IOPub thread"""
        self.thread.start()

    def stop(self):
        """Stop the IOPub thread and wait for it to exit.

        A no-op when the thread is not running (never started or already
        stopped), rather than failing inside ``Thread.join``.
        """
        if not self.thread.is_alive():
            return
        self.io_loop.add_callback(self.io_loop.stop)
        self.thread.join()

    def close(self):
        self.socket.close()
        self.socket = None

    @property
    def closed(self):
        return self.socket is None

    def send_multipart(self, *args, **kwargs):
        """send_multipart schedules actual zmq send in my thread.

        If my thread isn't running (e.g. forked process), send immediately.
        """
        if self.thread.is_alive():
            self.io_loop.add_callback(lambda : self._really_send(*args, **kwargs))
        else:
            self._really_send(*args, **kwargs)

    def _really_send(self, msg, *args, **kwargs):
        """The callback that actually sends messages"""
        mp_mode = self._check_mp_mode()

        if mp_mode != CHILD:
            # we are master, do a regular send
            self.socket.send_multipart(msg, *args, **kwargs)
        else:
            # we are a child, pipe to master
            # track the send so we can wait (up to 1s) for it to leave the process
            kwargs['copy'] = False
            kwargs['track'] = True
            tracker = self._pipe_out.send_multipart([self._pipe_uuid] + msg, *args, **kwargs)
            try:
                tracker.wait(1)
            except Exception as e:
                # best effort: report on the real stderr and carry on
                print("Failed to send: %s" % e, file=sys.__stderr__)
class TestFutureSocket(BaseZMQTestCase):
    """Tests for the Future-returning socket API driven by a per-test IOLoop."""

    Context = future.Context

    def setUp(self):
        # a fresh loop per test, installed as current before the base class
        # builds the context and sockets
        self.loop = IOLoop()
        self.loop.make_current()
        super(TestFutureSocket, self).setUp()

    def tearDown(self):
        super(TestFutureSocket, self).tearDown()
        self.loop.close(all_fds=True)

    def test_socket_class(self):
        sock = self.context.socket(zmq.PUSH)
        assert isinstance(sock, future.Socket)
        sock.close()

    def test_recv_multipart(self):
        @gen.coroutine
        def scenario():
            sender, receiver = self.create_bound_pair(zmq.PUSH, zmq.PULL)
            pending = receiver.recv_multipart()
            # nothing has been sent yet
            assert not pending.done()
            yield sender.send(b'hi')
            frames = yield pending
            self.assertEqual(frames, [b'hi'])

        self.loop.run_sync(scenario)

    def test_recv(self):
        @gen.coroutine
        def scenario():
            sender, receiver = self.create_bound_pair(zmq.PUSH, zmq.PULL)
            first = receiver.recv()
            second = receiver.recv()
            assert not first.done()
            assert not second.done()
            # the two frames are expected to satisfy the two recv() calls in order
            yield sender.send_multipart([b'hi', b'there'])
            payload = yield second
            assert first.done()
            self.assertEqual(first.result(), b'hi')
            self.assertEqual(payload, b'there')

        self.loop.run_sync(scenario)

    def test_recv_cancel(self):
        @gen.coroutine
        def scenario():
            sender, receiver = self.create_bound_pair(zmq.PUSH, zmq.PULL)
            dropped = receiver.recv()
            pending = receiver.recv_multipart()
            assert dropped.cancel()
            assert dropped.done()
            assert not pending.done()
            yield sender.send_multipart([b'hi', b'there'])
            frames = yield pending
            assert dropped.cancelled()
            assert pending.done()
            self.assertEqual(frames, [b'hi', b'there'])

        self.loop.run_sync(scenario)

    def test_poll(self):
        @gen.coroutine
        def scenario():
            sender, receiver = self.create_bound_pair(zmq.PUSH, zmq.PULL)
            immediate = receiver.poll(timeout=0)
            self.assertEqual(immediate.result(), 0)

            short_wait = receiver.poll(timeout=1)
            assert not short_wait.done()
            events = yield short_wait
            self.assertEqual(events, 0)

            long_wait = receiver.poll(timeout=1000)
            assert not long_wait.done()
            yield sender.send_multipart([b'hi', b'there'])
            events = yield long_wait
            self.assertEqual(events, zmq.POLLIN)
            frames = yield receiver.recv_multipart()
            self.assertEqual(frames, [b'hi', b'there'])

        self.loop.run_sync(scenario)
class IOPubThread(object):
    """An object for sending IOPub messages in a background thread

    prevents a blocking main thread

    IOPubThread(pub_socket).background_socket is a Socket-API-providing object
    whose IO is always run in a thread.
    """

    def __init__(self, socket, pipe=False):
        """Create the IOPub thread (it is not started here).

        Parameters
        ----------
        socket : zmq socket
            the socket on which messages will be sent.
        pipe : bool
            whether to listen for messages piped from forked subprocesses.
        """
        self.socket = socket
        self.background_socket = BackgroundSocket(self)
        self._master_pid = os.getpid()
        self._pipe_flag = pipe
        self.io_loop = IOLoop()
        if pipe:
            self._setup_pipe_in()
        self.thread = threading.Thread(target=self._thread_main)
        self.thread.daemon = True

    def _thread_main(self):
        """The inner loop that's actually run in a thread"""
        self.io_loop.start()
        # all_fds=True also closes the sockets registered on this loop
        # (e.g. the ZMQStream-wrapped subprocess pipe) instead of leaving them open
        self.io_loop.close(all_fds=True)

    def _setup_pipe_in(self):
        """setup listening pipe for subprocesses"""
        ctx = self.socket.context

        # use UUID to authenticate pipe messages
        self._pipe_uuid = uuid.uuid4().bytes

        pipe_in = ctx.socket(zmq.PULL)
        pipe_in.linger = 0
        try:
            self._pipe_port = pipe_in.bind_to_random_port("tcp://127.0.0.1")
        except zmq.ZMQError as e:
            warnings.warn("Couldn't bind IOPub Pipe to 127.0.0.1: %s" % e +
                "\nsubprocess output will be unavailable."
            )
            # degrade gracefully: _check_mp_mode will report MASTER everywhere
            self._pipe_flag = False
            pipe_in.close()
            return
        self._pipe_in = ZMQStream(pipe_in, self.io_loop)
        self._pipe_in.on_recv(self._handle_pipe_msg)

    def _handle_pipe_msg(self, msg):
        """handle a pipe message from a subprocess

        ``msg[0]`` must equal the pipe UUID; the remaining frames are re-sent
        on the IOPub socket. Mismatched messages are reported and dropped.
        """
        if not self._pipe_flag or not self._is_master_process():
            return
        if msg[0] != self._pipe_uuid:
            # format explicitly: passing msg as a second print() argument
            # would emit a literal "%s"
            print("Bad pipe message: %s" % (msg,), file=sys.__stderr__)
            return
        self.send_multipart(msg[1:])

    def _setup_pipe_out(self):
        """Create a fresh context and PUSH socket connected to the master's pipe.

        Returns
        -------
        (ctx, pipe_out) : the new zmq.Context and its connected PUSH socket.
        """
        # must be new context after fork
        ctx = zmq.Context()
        pipe_out = ctx.socket(zmq.PUSH)
        pipe_out.linger = 3000 # 3s timeout for pipe_out sends before discarding the message
        pipe_out.connect("tcp://127.0.0.1:%i" % self._pipe_port)
        return ctx, pipe_out

    def _is_master_process(self):
        return os.getpid() == self._master_pid

    def _check_mp_mode(self):
        """check for forks, and switch to zmq pipeline if necessary"""
        if not self._pipe_flag or self._is_master_process():
            return MASTER
        else:
            return CHILD

    def start(self):
        """Start the IOPub thread"""
        self.thread.start()
        # make sure we don't prevent process exit
        # I'm not sure why setting daemon=True above isn't enough, but it doesn't appear to be.
        atexit.register(self.stop)

    def stop(self):
        """Stop the IOPub thread (no-op if it is not running)."""
        if not self.thread.is_alive():
            return
        self.io_loop.add_callback(self.io_loop.stop)
        self.thread.join()

    def close(self):
        self.socket.close()
        self.socket = None

    @property
    def closed(self):
        return self.socket is None

    def send_multipart(self, *args, **kwargs):
        """send_multipart schedules actual zmq send in my thread.

        If my thread isn't running (e.g. forked process), send immediately.
        """
        if self.thread.is_alive():
            self.io_loop.add_callback(lambda : self._really_send(*args, **kwargs))
        else:
            self._really_send(*args, **kwargs)

    def _really_send(self, msg, *args, **kwargs):
        """The callback that actually sends messages"""
        mp_mode = self._check_mp_mode()

        if mp_mode != CHILD:
            # we are master, do a regular send
            self.socket.send_multipart(msg, *args, **kwargs)
        else:
            # we are a child, pipe to master
            # new context/socket for every pipe-out
            # since forks don't teardown politely, use ctx.term to ensure send has completed
            ctx, pipe_out = self._setup_pipe_out()
            try:
                pipe_out.send_multipart([self._pipe_uuid] + msg, *args, **kwargs)
            finally:
                # release the per-send socket/context even if the send raised
                pipe_out.close()
                ctx.term()
class TestFutureSocket(BaseZMQTestCase):
    """Tests for the Future-returning socket API driven by a per-test IOLoop."""

    Context = future.Context

    def setUp(self):
        # a fresh loop per test, installed as current before the base class
        # builds the context and sockets
        self.loop = IOLoop()
        self.loop.make_current()
        super(TestFutureSocket, self).setUp()

    def tearDown(self):
        super(TestFutureSocket, self).tearDown()
        self.loop.close(all_fds=True)

    def test_socket_class(self):
        sock = self.context.socket(zmq.PUSH)
        assert isinstance(sock, future.Socket)
        sock.close()

    def test_recv_multipart(self):
        @gen.coroutine
        def scenario():
            sender, receiver = self.create_bound_pair(zmq.PUSH, zmq.PULL)
            pending = receiver.recv_multipart()
            # nothing has been sent yet
            assert not pending.done()
            yield sender.send(b'hi')
            frames = yield pending
            self.assertEqual(frames, [b'hi'])

        self.loop.run_sync(scenario)

    def test_recv(self):
        @gen.coroutine
        def scenario():
            sender, receiver = self.create_bound_pair(zmq.PUSH, zmq.PULL)
            first = receiver.recv()
            second = receiver.recv()
            assert not first.done()
            assert not second.done()
            # the two frames are expected to satisfy the two recv() calls in order
            yield sender.send_multipart([b'hi', b'there'])
            payload = yield second
            assert first.done()
            self.assertEqual(first.result(), b'hi')
            self.assertEqual(payload, b'there')

        self.loop.run_sync(scenario)

    def test_recv_cancel(self):
        @gen.coroutine
        def scenario():
            sender, receiver = self.create_bound_pair(zmq.PUSH, zmq.PULL)
            dropped = receiver.recv()
            pending = receiver.recv_multipart()
            assert dropped.cancel()
            assert dropped.done()
            assert not pending.done()
            yield sender.send_multipart([b'hi', b'there'])
            frames = yield pending
            assert dropped.cancelled()
            assert pending.done()
            self.assertEqual(frames, [b'hi', b'there'])

        self.loop.run_sync(scenario)

    @pytest.mark.skipif(not hasattr(zmq, 'RCVTIMEO'), reason="requires RCVTIMEO")
    def test_recv_timeout(self):
        @gen.coroutine
        def scenario():
            sender, receiver = self.create_bound_pair(zmq.PUSH, zmq.PULL)
            # the first recv is issued under a 100 ms timeout and is expected
            # to expire; the second one is issued under 1000 ms
            receiver.rcvtimeo = 100
            expiring = receiver.recv()
            receiver.rcvtimeo = 1000
            pending = receiver.recv_multipart()
            with pytest.raises(zmq.Again):
                yield expiring
            yield sender.send_multipart([b'hi', b'there'])
            frames = yield pending
            assert pending.done()
            self.assertEqual(frames, [b'hi', b'there'])

        self.loop.run_sync(scenario)

    @pytest.mark.skipif(not hasattr(zmq, 'SNDTIMEO'), reason="requires SNDTIMEO")
    def test_send_timeout(self):
        @gen.coroutine
        def scenario():
            # this socket is never bound or connected, so the send cannot complete
            lonely = self.socket(zmq.PUSH)
            lonely.sndtimeo = 100
            with pytest.raises(zmq.Again):
                yield lonely.send(b'not going anywhere')

        self.loop.run_sync(scenario)

    def test_recv_string(self):
        @gen.coroutine
        def scenario():
            sender, receiver = self.create_bound_pair(zmq.PUSH, zmq.PULL)
            pending = receiver.recv_string()
            assert not pending.done()
            text = u('πøøπ')
            yield sender.send_string(text)
            received = yield pending
            assert pending.done()
            self.assertEqual(pending.result(), text)
            self.assertEqual(received, text)

        self.loop.run_sync(scenario)

    def test_recv_json(self):
        @gen.coroutine
        def scenario():
            sender, receiver = self.create_bound_pair(zmq.PUSH, zmq.PULL)
            pending = receiver.recv_json()
            assert not pending.done()
            payload = dict(a=5)
            yield sender.send_json(payload)
            received = yield pending
            assert pending.done()
            self.assertEqual(pending.result(), payload)
            self.assertEqual(received, payload)

        self.loop.run_sync(scenario)

    def test_recv_pyobj(self):
        @gen.coroutine
        def scenario():
            sender, receiver = self.create_bound_pair(zmq.PUSH, zmq.PULL)
            pending = receiver.recv_pyobj()
            assert not pending.done()
            payload = dict(a=5)
            yield sender.send_pyobj(payload)
            received = yield pending
            assert pending.done()
            self.assertEqual(pending.result(), payload)
            self.assertEqual(received, payload)

        self.loop.run_sync(scenario)

    def test_poll(self):
        @gen.coroutine
        def scenario():
            sender, receiver = self.create_bound_pair(zmq.PUSH, zmq.PULL)
            immediate = receiver.poll(timeout=0)
            self.assertEqual(immediate.result(), 0)

            short_wait = receiver.poll(timeout=1)
            assert not short_wait.done()
            events = yield short_wait
            self.assertEqual(events, 0)

            long_wait = receiver.poll(timeout=1000)
            assert not long_wait.done()
            yield sender.send_multipart([b'hi', b'there'])
            events = yield long_wait
            self.assertEqual(events, zmq.POLLIN)
            frames = yield receiver.recv_multipart()
            self.assertEqual(frames, [b'hi', b'there'])

        self.loop.run_sync(scenario)
class ZMQ(Client):
    """ZeroMQ transport for the CIF SDK client.

    Each request connects ``self.socket`` to ``self.remote``, sends one
    ``Msg`` and, unless in no-wait mode, reads one reply. The socket is
    closed once the reply has been read (or reading it failed). Fireball
    mode (``indicators_create(..., fireball=True)``) swaps in a DEALER
    socket and streams batches of indicators, collecting replies on an IOLoop.
    """

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.socket.close()
        # NOTE(review): self.context is zmq.Context.instance(), so this
        # terminates the process-wide shared context, not a private one.
        self.context.term()

    def __init__(self, remote, token, **kwargs):
        """
        :param remote: endpoint passed to ``socket.connect``
        :param token: auth token attached to every ``Msg``
        :param kwargs: ``autoclose`` (default True) closes the socket after a
            no-wait send; ``nowait`` (default False) sends on a DEALER socket
            and never waits for a reply.
        """
        super(ZMQ, self).__init__(remote, token)

        self.context = zmq.Context.instance()
        self.autoclose = kwargs.get('autoclose', True)
        self.nowait = kwargs.get('nowait', False)

        # Pick the socket type up front instead of creating a REQ socket and
        # then abandoning it (unclosed) when nowait is requested.
        if self.nowait:
            # DEALER can send without a matching recv. Created with default
            # options, exactly as before.
            self.socket = self.context.socket(zmq.DEALER)
        else:
            self.socket = self.context.socket(zmq.REQ)
            self.socket.RCVTIMEO = RCVTIMEO
            self.socket.SNDTIMEO = SNDTIMEO
            self.socket.setsockopt(zmq.LINGER, LINGER)

    def _recv(self, decode=True):
        """Read one reply from ``self.socket`` and unwrap its payload.

        :param decode: when False, return the raw message data untouched
        :raises AuthError: server replied 'unauthorized'
        :raises CIFBusy: server replied 'busy'
        :raises InvalidSearch: server replied 'invalid search'
        :raises RuntimeError: any other non-success status, or no 'data'
        """
        _mtype, data = Msg().recv(self.socket)

        if not decode:
            return data

        data = json.loads(data)

        if data.get('message') == 'unauthorized':
            raise AuthError()

        if data.get('message') == 'busy':
            raise CIFBusy()

        if data.get('message') == 'invalid search':
            raise InvalidSearch()

        if data.get('status') != 'success':
            raise RuntimeError(data.get('message'))

        if data.get('data') is None:
            raise RuntimeError('invalid response')

        if isinstance(data.get('data'), bool):
            return data['data']

        # is this a straight up elasticsearch string?
        if data['data'] == '{}':
            return []

        if isinstance(data['data'], basestring) and data['data'].startswith(
                '{"hits":{"hits":[{"_source":'):
            data['data'] = json.loads(data['data'])
            data['data'] = [r['_source'] for r in data['data']['hits']['hits']]

        # best effort: payloads that are not zlib-compressed are returned as-is
        try:
            data['data'] = zlib.decompress(data['data'])
        except (zlib.error, TypeError):
            pass

        return data.get('data')

    def _send(self, mtype, data='[]', nowait=False, decode=True):
        """Connect, send one ``Msg`` and (unless no-wait) return the decoded reply.

        :returns: None in no-wait mode, otherwise the result of ``_recv``
        """
        self.socket.connect(self.remote)

        if isinstance(data, str):
            data = data.encode('utf-8')

        Msg(mtype=mtype, token=self.token, data=data).send(self.socket)

        if self.nowait or nowait:
            if self.autoclose:
                self.socket.close()
            logger.debug('not waiting for a resp')
            return

        try:
            return self._recv(decode=decode)
        finally:
            # close on success as before, and also when _recv raises
            self.socket.close()

    def ping(self, write=False):
        if write:
            return self._send(Msg.PING_WRITE)
        return self._send(Msg.PING)

    def tokens_search(self, filters=None):
        # None instead of a shared mutable {} default; same request either way
        return self._send(Msg.TOKENS_SEARCH, json.dumps({} if filters is None else filters))

    def tokens_create(self, data):
        return self._send(Msg.TOKENS_CREATE, data)

    def tokens_delete(self, data):
        return self._send(Msg.TOKENS_DELETE, data)

    def tokens_edit(self, data):
        return self._send(Msg.TOKENS_EDIT, data)

    def _handle_message_fireball(self, s, e):
        """IOLoop handler: collect one fireball reply; stop once all arrived."""
        logger.debug('message received')
        _id, _mtype, data = Msg().recv(s)

        self.response.append(data)
        self.num_responses -= 1
        logger.debug('num responses remaining: %i', self.num_responses)

        if self.num_responses == 0:
            logger.debug('finishing up...')
            self.loop.stop()

    def _fireball_timeout(self):
        """IOLoop timeout: record the timeout and stop the loop.

        The IOLoop logs exceptions raised inside its callbacks instead of
        propagating them out of ``loop.start()``, so the TimeoutError is
        raised by ``_send_fireball`` after the loop returns.
        """
        logger.warning('timeout')
        self._fireball_timed_out = True
        self.loop.stop()

    def _send_fireball(self, mtype, data):
        """Send ``data`` (a JSON list) in batches of FIREBALL_SIZE and gather replies.

        :returns: list with one reply payload per batch
        :raises TimeoutError: if not all replies arrive before the deadline
        """
        if len(data) < 3:
            logger.error('no data to send')
            return []

        # parse before allocating the loop/socket so a bad payload leaks nothing
        if PYVERSION == 3:
            if isinstance(data, bytes):
                data = data.decode('utf-8')

        data = json.loads(data)

        if not isinstance(data, list):
            data = [data]

        if not data:
            # nothing to send: no replies would ever arrive
            logger.error('no data to send')
            return []

        # one reply per batch: ceil(len(data) / FIREBALL_SIZE)
        self.num_responses = (len(data) + FIREBALL_SIZE - 1) // FIREBALL_SIZE
        logger.debug('responses expected: %i', self.num_responses)

        self.loop = IOLoop()
        # replace (and release) the request socket with a DEALER so that
        # several batches can be in flight at once
        self.socket.close()
        self.socket = self.context.socket(zmq.DEALER)
        self.socket.connect(self.remote)

        self.response = []
        self._fireball_timed_out = False
        # deadline on the loop's own clock (loop.time() is not necessarily time.time())
        # NOTE(review): SNDTIMEO is also applied as a zmq socket option, which
        # ZeroMQ interprets in milliseconds, but it is added here as seconds.
        # Confirm the intended unit.
        timeout = self.loop.time() + SNDTIMEO
        self.loop.add_timeout(timeout, self._fireball_timeout)
        self.loop.add_handler(self.socket, self._handle_message_fireball, zmq.POLLIN)

        try:
            for start in range(0, len(data), FIREBALL_SIZE):
                batch = data[start:start + FIREBALL_SIZE]
                Msg(mtype=mtype, token=self.token, data=batch).send(self.socket)

            logger.debug("starting loop to receive")
            self.loop.start()
        finally:
            self.loop.close()
            self.socket.close()

        if self._fireball_timed_out:
            raise TimeoutError('timeout')

        return self.response

    def indicators_search(self, filters, decode=True):
        return self._send(Msg.INDICATORS_SEARCH, json.dumps(filters), decode=decode)

    def indicators_create(self, data, nowait=False, fireball=False):
        if isinstance(data, dict):
            data = self._kv_to_indicator(data)

        if isinstance(data, Indicator):
            data = str(data)

        if fireball:
            logger.info('using fireball mode')
            return self._send_fireball(Msg.INDICATORS_CREATE, data)

        return self._send(Msg.INDICATORS_CREATE, data, nowait=nowait)

    def indicators_delete(self, data):
        if isinstance(data, dict):
            data = self._kv_to_indicator(data)

        if isinstance(data, Indicator):
            data = str(data)

        return self._send(Msg.INDICATORS_DELETE, data)