def test_default_version():
    """A header without a 'version' field is adapted to V4toV5.version."""
    session = Session()
    message = session.msg("msg_type")
    # Simulate a peer that predates header versioning.
    del message['header']['version']
    snapshot = copy.deepcopy(message)
    result = adapt(snapshot)
    nt.assert_equal(result['header']['version'], V4toV5.version)
def test_deserialize_binary():
    """A message with memoryview buffers survives a binary round trip unchanged."""
    session = Session()
    original = session.msg('data_pub', content={'a': 'b'})
    original['buffers'] = [memoryview(os.urandom(2)) for _ in range(3)]
    wire = serialize_binary_message(original)
    restored = deserialize_binary_message(wire)
    nt.assert_equal(restored, original)
def test_default_version():
    """adapt() fills in V4toV5.version when the message header has none."""
    msg = Session().msg("msg_type")
    header = msg["header"]
    header.pop("version")
    adapted = adapt(copy.deepcopy(msg))
    expected = V4toV5.version
    nt.assert_equal(adapted["header"]["version"], expected)
def test_deserialize_binary():
    """Round-trip a message carrying raw bytes buffers through binary (de)serialization."""
    msg = Session().msg('data_pub', content={'a': 'b'})
    buffers = []
    for _ in range(3):
        buffers.append(os.urandom(2))
    msg['buffers'] = buffers
    msg2 = deserialize_binary_message(serialize_binary_message(msg))
    nt.assert_equal(msg2, msg)
def main(connection_file):
    """watch iopub channel, and print messages

    Reads a JSON connection file providing 'location', 'url' and 'exec_key'.
    It sends a ``connection_request`` over a DEALER socket to learn the iopub
    URL, then subscribes to iopub. It prints ``stream`` and ``pyerr`` messages
    until interrupted with Ctrl-C.

    Parameters
    ----------
    connection_file : str
        Path to the JSON connection file.
    """
    ctx = zmq.Context.instance()

    with open(connection_file) as f:
        cfg = json.load(f)

    location = cfg['location']
    session = Session(key=str_to_bytes(cfg['exec_key']))

    # Ask the controller where its iopub channel lives; this socket is only
    # needed for the one handshake, so close it as soon as we have the reply.
    query = ctx.socket(zmq.DEALER)
    try:
        query.connect(disambiguate_url(cfg['url'], location))
        session.send(query, "connection_request")
        idents, msg = session.recv(query, mode=0)
    finally:
        query.close()

    c = msg['content']
    iopub_url = disambiguate_url(c['iopub'], location)
    sub = ctx.socket(zmq.SUB)

    # This will subscribe to all messages:
    sub.setsockopt(zmq.SUBSCRIBE, b'')
    # replace b'' with b'engine.1.stdout' to subscribe only to engine 1's stdout
    # 0MQ subscriptions are simple 'foo*' matches, so 'engine.1.' subscribes
    # to everything from engine 1, but there is no way to subscribe to
    # just stdout from everyone.
    # multiple calls to subscribe will add subscriptions, e.g. to subscribe to
    # engine 1's stderr and engine 2's stdout:
    # sub.setsockopt(zmq.SUBSCRIBE, b'engine.1.stderr')
    # sub.setsockopt(zmq.SUBSCRIBE, b'engine.2.stdout')
    sub.connect(iopub_url)
    try:
        while True:
            try:
                idents, msg = session.recv(sub, mode=0)
            except KeyboardInterrupt:
                return
            # ident always length 1 here
            topic = idents[0]
            # On Python 3 zmq identities are bytes; `topic + ':'` would raise
            # TypeError and '%s' would print a b'...' repr.
            if isinstance(topic, bytes):
                topic = topic.decode('utf-8', 'replace')
            if msg['msg_type'] == 'stream':
                # stdout/stderr
                # stream names are in msg['content']['name'], if you want to handle
                # them differently
                print("%s: %s" % (topic, msg['content']['data']))
            elif msg['msg_type'] == 'pyerr':
                # Python traceback
                c = msg['content']
                print(topic + ':')
                for line in c['traceback']:
                    # indent lines
                    print(' ' + line)
    finally:
        sub.close()
def test_load_connection_file_session():
    """test load_connection_file() after the app's session is initialized

    The already-initialized ``app.session`` must pick up the key and
    signature scheme from the connection file loaded afterwards.
    """
    app = DummyConsoleApp(session=Session())
    app.initialize(argv=[])
    session = app.session

    with TemporaryDirectory() as d:
        cf = os.path.join(d, 'kernel.json')
        connect.write_connection_file(cf, **sample_info)
        app.connection_file = cf
        app.load_connection_file()

    nt.assert_equal(session.key, sample_info['key'])
    nt.assert_equal(session.signature_scheme, sample_info['signature_scheme'])
def start_notebook(url, port, user):
    """Repeatedly load, run and save ``Hello.ipynb`` as ``user`` via JupyterHub.

    Generator-based coroutine: each ``yield`` hands an awaitable to the
    driving event loop (presumably a coroutine decorator outside this view).
    It loops forever. Each outer iteration starts a fresh kernel, runs the
    notebook 20 times, saves it, and kills the kernel even if a run fails.
    """
    hub_url = 'https://%s:%s/hub' % (url, port)
    user_url = 'https://%s:%s/user/%s' % (url, port, user)
    # NOTE(review): logs in with password == username; test deployments only.
    cookies = login(hub_url, user, user)
    api = NBAPI(url=user_url, cookies=cookies)
    path = 'Hello.ipynb'
    for i in itertools.count():
        gen_log.info("loading %s (%s)", user, i)
        nb = api.get_notebook(path)
        gen_log.info("starting %s (%s)", user, i)
        session = Session()
        kernel = yield api.new_kernel(session.session)
        try:
            for j in range(20):
                gen_log.info("running %s (%s:%s)", user, j, i)
                yield run_notebook(nb, kernel, session)
                yield sleep(0.05)
            gen_log.info("saving %s (%s)", user, i)
            api.save_notebook(nb, path)
        finally:
            api.kill_kernel(kernel['id'])
def main(connection_file):
    """watch iopub channel, and print messages

    Reads a JSON connection file providing 'interface', 'iopub' and 'key'.
    It subscribes to the iopub socket at ``<interface>:<iopub>`` and prints
    stream output and error tracebacks until interrupted with Ctrl-C.

    Parameters
    ----------
    connection_file : str
        Path to the JSON connection file.
    """
    ctx = zmq.Context.instance()

    with open(connection_file) as f:
        cfg = json.load(f)

    reg_url = cfg['interface']
    iopub_port = cfg['iopub']
    iopub_url = "%s:%s" % (reg_url, iopub_port)

    session = Session(key=str_to_bytes(cfg['key']))
    sub = ctx.socket(zmq.SUB)

    # This will subscribe to all messages:
    sub.setsockopt(zmq.SUBSCRIBE, b'')
    # replace b'' with b'engine.1.stdout' to subscribe only to engine 1's stdout
    # 0MQ subscriptions are simple 'foo*' matches, so 'engine.1.' subscribes
    # to everything from engine 1, but there is no way to subscribe to
    # just stdout from everyone.
    # multiple calls to subscribe will add subscriptions, e.g. to subscribe to
    # engine 1's stderr and engine 2's stdout:
    # sub.setsockopt(zmq.SUBSCRIBE, b'engine.1.stderr')
    # sub.setsockopt(zmq.SUBSCRIBE, b'engine.2.stdout')
    sub.connect(iopub_url)
    try:
        while True:
            try:
                idents, msg = session.recv(sub, mode=0)
            except KeyboardInterrupt:
                return
            # ident always length 1 here
            topic = idents[0]
            # On Python 3 zmq identities are bytes; `topic + ':'` would raise
            # TypeError and '%s' would print a b'...' repr.
            if isinstance(topic, bytes):
                topic = topic.decode('utf-8', 'replace')
            if msg['msg_type'] == 'stream':
                # stdout/stderr
                # stream names are in msg['content']['name'], if you want to handle
                # them differently
                print("%s: %s" % (topic, msg['content']['text']))
            elif msg['msg_type'] in ('error', 'pyerr'):
                # Python traceback. The 'text' stream field above is the v5
                # protocol, where tracebacks arrive as 'error'; 'pyerr' is the
                # v4 name, kept for older peers.
                c = msg['content']
                print(topic + ':')
                for line in c['traceback']:
                    # indent lines
                    print(' ' + line)
    finally:
        sub.close()
def open(self, kernel_id):
    """Set up a websocket connection for kernel ``kernel_id``.

    A request whose origin does not match the host is logged and rejected
    with a 404. Otherwise the kernel id is stored, a Session is built from
    this handler's config, and the first incoming message is routed to
    ``on_first_message``. The original handler is kept in ``save_on_message``.
    """
    # Check to see that origin matches host directly, including ports
    if not self.same_origin():
        # Logger.warn is a deprecated alias of Logger.warning.
        self.log.warning("Cross Origin WebSocket Attempt.")
        raise web.HTTPError(404)

    self.kernel_id = cast_unicode(kernel_id, 'ascii')
    self.session = Session(config=self.config)
    self.save_on_message = self.on_message
    self.on_message = self.on_first_message
def open(self, kernel_id):
    """Bind the websocket to ``kernel_id`` after an origin check on old Tornado.

    On Tornado < 4 a failed origin check raises HTTP 403; newer Tornado does
    this check itself. The first incoming message is then diverted to
    ``on_first_message``, with the original handler kept in ``save_on_message``.
    """
    self.kernel_id = cast_unicode(kernel_id, 'ascii')
    # Check to see that origin matches host directly, including ports.
    # Tornado 4 already does CORS checking, so only older versions need it here.
    legacy_tornado = tornado.version_info[0] < 4
    if legacy_tornado and not self.check_origin(self.get_origin()):
        raise web.HTTPError(403)
    self.session = Session(config=self.config)
    self.save_on_message, self.on_message = self.on_message, self.on_first_message
def open(self, kernel_id):
    """Bind the websocket to ``kernel_id`` and prepare its message session.

    Builds a Session from the application's config, or with ``config=None``
    when the application has no ``config`` attribute. The first incoming
    message is routed to ``on_first_message``.
    """
    # kernel_id can be bytes (Python 2 / old tornado) or already text. Only
    # bytes has .decode(), so decoding unconditionally breaks on Python 3 str.
    if isinstance(kernel_id, bytes):
        kernel_id = kernel_id.decode('ascii')
    self.kernel_id = kernel_id
    try:
        cfg = self.application.config
    except AttributeError:
        # protect from the case where this is run from something other than
        # the notebook app:
        cfg = None
    self.session = Session(config=cfg)
    self.save_on_message = self.on_message
    self.on_message = self.on_first_message
def _setup_session(self, reply, comp_id, timeout=None):
    """
    Set up the kernel information contained in the untrusted reply message
    `reply` from computer `comp_id`.

    Records the kernel in ``self._kernels`` and links it under
    ``self._comps[comp_id]["kernels"]``. It also creates a Session keyed with
    the connection's "key" in ``self._sessions``. When `timeout` is None the
    stored "timeout" is ``time.time() + self.kernel_timeout``; otherwise
    `timeout` is stored unchanged.
    """
    content = reply["content"]
    kernel_id = content["kernel_id"]
    connection = content["connection"]

    # NOTE(review): the default is an absolute time; a caller-supplied value
    # is stored as-is, presumably also an absolute time — confirm with callers.
    if timeout is None:
        timeout_value = time.time() + self.kernel_timeout
    else:
        timeout_value = timeout

    self._kernels[kernel_id] = {
        "comp_id": comp_id,
        "connection": connection,
        "executing": 0,  # number of active execute_requests
        "timeout": timeout_value,
    }
    self._comps[comp_id]["kernels"][kernel_id] = None
    self._sessions[kernel_id] = Session(key=connection["key"])
def open_run_save(api, path):
    """open a notebook, run it, and save.

    Only the original notebook is saved, the output is not recorded.
    The kernel started for the run is always killed, even if the run fails.
    """
    notebook = api.get_notebook(path)
    sess = Session()
    kernel = yield api.new_kernel(sess.session)
    try:
        yield run_notebook(notebook, kernel, sess)
    finally:
        api.kill_kernel(kernel['id'])
    gen_log.info("Saving %s/notebooks/%s", api.url, path)
    api.save_notebook(notebook, path)
def open(self, kernel_id):
    """Open a websocket for kernel ``kernel_id`` and start keep-alive pings.

    On Tornado < 4 the request origin is checked here (Tornado 4 does it
    itself); a mismatch is logged and rejected with HTTP 403. A Session is
    built from this handler's config, and the first incoming message is
    routed to ``on_first_message``. ``send_ping`` is then scheduled through
    a PeriodicCallback with period ``WS_PING_INTERVAL`` (in milliseconds).
    """
    self.kernel_id = cast_unicode(kernel_id, 'ascii')
    # Check to see that origin matches host directly, including ports
    # Tornado 4 already does CORS checking
    if tornado.version_info[0] < 4:
        origin = self.get_origin()
        if not self.check_origin(origin):
            # Logger.warn is a deprecated alias of Logger.warning.
            self.log.warning("Cross Origin WebSocket Attempt from %s", origin)
            raise web.HTTPError(403)
    self.session = Session(config=self.config)
    self.save_on_message = self.on_message
    self.on_message = self.on_first_message
    self.ping_callback = ioloop.PeriodicCallback(self.send_ping, WS_PING_INTERVAL)
    self.ping_callback.start()
def open(self, kernel_id):
    """Open a websocket for kernel ``kernel_id`` and, if enabled, start pinging.

    On Tornado < 4 the request origin is checked here (Tornado 4 does it
    itself); a mismatch is logged and rejected with HTTP 403. A Session is
    built from this handler's config, and the first incoming message is
    routed to ``on_first_message``. When ``self.ping_interval > 0``,
    ``send_ping`` is scheduled with that period (milliseconds, per
    PeriodicCallback) and the last ping/pong times are initialized.
    """
    self.kernel_id = cast_unicode(kernel_id, 'ascii')
    # Check to see that origin matches host directly, including ports
    # Tornado 4 already does CORS checking
    if tornado.version_info[0] < 4:
        origin = self.get_origin()
        if not self.check_origin(origin):
            # Logger.warn is a deprecated alias of Logger.warning.
            self.log.warning("Cross Origin WebSocket Attempt from %s", origin)
            raise web.HTTPError(403)
    self.session = Session(config=self.config)
    self.save_on_message = self.on_message
    self.on_message = self.on_first_message

    # start the pinging
    if self.ping_interval > 0:
        self.last_ping = ioloop.IOLoop.instance().time()  # Remember time of last ping
        self.last_pong = self.last_ping
        self.ping_callback = ioloop.PeriodicCallback(self.send_ping, self.ping_interval)
        self.ping_callback.start()
def open(self, kernel_id):
    """Bind the websocket to ``kernel_id`` and prepare its message session.

    A Session is built from this handler's config, and the first incoming
    message is routed to ``on_first_message``. The original handler is kept
    in ``save_on_message``.
    """
    # kernel_id can be bytes (Python 2 / old tornado) or already text. Only
    # bytes has .decode(), so decoding unconditionally breaks on Python 3 str.
    if isinstance(kernel_id, bytes):
        kernel_id = kernel_id.decode('ascii')
    self.kernel_id = kernel_id
    self.session = Session(config=self.config)
    self.save_on_message = self.on_message
    self.on_message = self.on_first_message
def setUp(self):
    """Create a fresh Session and database, then seed 16 records."""
    session = Session()
    self.session = session
    db = self.create_db()
    self.db = db
    self.load_records(16)
class TaskDBTest:
    """Backend-independent tests for a task-record database (mixin).

    Subclasses implement :meth:`create_db` and must also provide the
    ``unittest.TestCase`` assertion methods (``assertEqual``, ``assertIn``,
    ``assertRaises``, ...) that the tests below call.
    """

    def setUp(self):
        self.session = Session()
        self.db = self.create_db()
        self.load_records(16)

    def create_db(self):
        """Return the database backend under test; subclasses must override."""
        raise NotImplementedError

    def load_records(self, n=1, buffer_size=100):
        """load n records for testing

        Each record wraps an ``apply_request`` message carrying one random
        buffer of ``buffer_size`` bytes. Returns the new msg_ids in
        insertion order.
        """
        # sleep 1/10 s, to ensure timestamp is different to previous calls
        time.sleep(0.1)
        msg_ids = []
        for _ in range(n):
            msg = self.session.msg('apply_request', content=dict(a=5))
            msg['buffers'] = [os.urandom(buffer_size)]
            rec = init_record(msg)
            msg_id = msg['header']['msg_id']
            msg_ids.append(msg_id)
            self.db.add_record(msg_id, rec)
        return msg_ids

    def test_add_record(self):
        before = self.db.get_history()
        self.load_records(5)
        after = self.db.get_history()
        self.assertEqual(len(after), len(before) + 5)
        self.assertEqual(after[:-5], before)

    def test_drop_record(self):
        msg_id = self.load_records()[-1]
        # the record must be retrievable (no KeyError) before it is dropped
        self.db.get_record(msg_id)
        self.db.drop_record(msg_id)
        self.assertRaises(KeyError, self.db.get_record, msg_id)

    def _round_to_millisecond(self, dt):
        """necessary because mongodb rounds microseconds

        Truncates ``dt`` to whole milliseconds by dropping the
        sub-millisecond part of its microseconds.
        """
        return dt - timedelta(microseconds=dt.microsecond % 1000)

    def test_update_record(self):
        now = self._round_to_millisecond(datetime.now())
        msg_id = self.db.get_history()[-1]
        rec1 = self.db.get_record(msg_id)
        data = {'stdout': 'hello there', 'completed': now}
        self.db.update_record(msg_id, data)
        rec2 = self.db.get_record(msg_id)
        self.assertEqual(rec2['stdout'], 'hello there')
        self.assertEqual(rec2['completed'], now)
        rec1.update(data)
        self.assertEqual(rec1, rec2)

    # def test_update_record_bad(self):
    #     """test updating nonexistant records"""
    #     msg_id = str(uuid.uuid4())
    #     data = {'stdout': 'hello there'}
    #     self.assertRaises(KeyError, self.db.update_record, msg_id, data)

    def test_find_records_dt(self):
        """test finding records by date"""
        hist = self.db.get_history()
        middle = self.db.get_record(hist[len(hist) // 2])
        tic = middle['submitted']
        before = self.db.find_records({'submitted': {'$lt': tic}})
        after = self.db.find_records({'submitted': {'$gte': tic}})
        self.assertEqual(len(before) + len(after), len(hist))
        for b in before:
            self.assertLess(b['submitted'], tic)
        for a in after:
            self.assertGreaterEqual(a['submitted'], tic)
        same = self.db.find_records({'submitted': tic})
        for s in same:
            self.assertEqual(s['submitted'], tic)

    def test_find_records_keys(self):
        """test extracting subset of record keys"""
        found = self.db.find_records({'msg_id': {'$ne': ''}}, keys=['submitted', 'completed'])
        for rec in found:
            self.assertEqual(set(rec), {'msg_id', 'submitted', 'completed'})

    def test_find_records_msg_id(self):
        """ensure msg_id is always in found records"""
        for keys in (['submitted', 'completed'], ['submitted'], ['msg_id']):
            found = self.db.find_records({'msg_id': {'$ne': ''}}, keys=keys)
            for rec in found:
                self.assertIn('msg_id', rec)

    def test_find_records_in(self):
        """test finding records with '$in','$nin' operators"""
        hist = self.db.get_history()
        even = hist[::2]
        odd = hist[1::2]
        recs = self.db.find_records({'msg_id': {'$in': even}})
        found = [r['msg_id'] for r in recs]
        self.assertEqual(set(even), set(found))
        recs = self.db.find_records({'msg_id': {'$nin': even}})
        found = [r['msg_id'] for r in recs]
        self.assertEqual(set(odd), set(found))

    def test_get_history(self):
        msg_ids = self.db.get_history()
        latest = datetime(1984, 1, 1)
        for msg_id in msg_ids:
            rec = self.db.get_record(msg_id)
            newt = rec['submitted']
            self.assertGreaterEqual(newt, latest)
            latest = newt
        msg_id = self.load_records(1)[-1]
        self.assertEqual(self.db.get_history()[-1], msg_id)

    def test_datetime(self):
        """get/set timestamps with datetime objects"""
        msg_id = self.db.get_history()[-1]
        rec = self.db.get_record(msg_id)
        self.assertIsInstance(rec['submitted'], datetime)
        self.db.update_record(msg_id, dict(completed=datetime.now()))
        rec = self.db.get_record(msg_id)
        self.assertIsInstance(rec['completed'], datetime)

    def test_drop_matching(self):
        msg_ids = self.load_records(10)
        query = {'msg_id': {'$in': msg_ids}}
        self.db.drop_matching_records(query)
        recs = self.db.find_records(query)
        self.assertEqual(len(recs), 0)

    def test_null(self):
        """test None comparison queries"""
        self.load_records(10)
        query = {'msg_id': None}
        recs = self.db.find_records(query)
        self.assertEqual(len(recs), 0)
        query = {'msg_id': {'$ne': None}}
        recs = self.db.find_records(query)
        self.assertGreaterEqual(len(recs), 10)

    def _tamper(self, rec):
        """Mutate a query result in the ways the pop_safe tests look for."""
        rec.pop('buffers')
        rec['garbage'] = 'hello'
        rec['header']['msg_id'] = 'fubar'

    def _assert_untampered(self, rec, msg_id):
        """Assert that ``rec`` shows none of the changes made by ``_tamper``."""
        self.assertIn('buffers', rec)
        self.assertNotIn('garbage', rec)
        self.assertEqual(rec['header']['msg_id'], msg_id)

    def test_pop_safe_get(self):
        """editing query results shouldn't affect record [get]"""
        msg_id = self.db.get_history()[-1]
        self._tamper(self.db.get_record(msg_id))
        self._assert_untampered(self.db.get_record(msg_id), msg_id)

    def test_pop_safe_find(self):
        """editing query results shouldn't affect record [find]"""
        msg_id = self.db.get_history()[-1]
        self._tamper(self.db.find_records({'msg_id': msg_id})[0])
        self._assert_untampered(self.db.find_records({'msg_id': msg_id})[0], msg_id)

    def test_pop_safe_find_keys(self):
        """editing query results shouldn't affect record [find+keys]"""
        msg_id = self.db.get_history()[-1]
        rec = self.db.find_records({'msg_id': msg_id}, keys=['buffers', 'header'])[0]
        self._tamper(rec)
        self._assert_untampered(self.db.find_records({'msg_id': msg_id})[0], msg_id)
def _session_default(self):
    """Build the default Session, configured from ``self.config``."""
    from IPython.kernel.zmq.session import Session as _Session
    session = _Session(config=self.config)
    return session
def setUp(self):
    """Give each test its own fresh Session."""
    session = Session()
    self.session = session
def _session_default(self):
    """Return the default Session: parented to this object, with an empty key."""
    # don't sign in-process messages (hence the empty key)
    return Session(parent=self, key=b'')
def _session_default(self):
    """Return the default Session: parented to this object, with an empty key."""
    from IPython.kernel.zmq.session import Session as _Session
    unsigned_key = b''
    return _Session(key=unsigned_key, parent=self)
# NOTE(review): orphaned fragment of a kernel-setup routine (compare initIPythonKernel): it has no enclosing def, its `try:` has no handler, and because the whole line below starts with '#', it is a single comment with no runtime effect. Consider deleting it or restoring the full function.
# https://github.com/ipython/ipython/issues/680 assert isinstance(threading.currentThread(), threading._MainThread) try: connection_file = "kernel-%s.json" % os.getpid() def cleanup_connection_file(): try: os.remove(connection_file) except (IOError, OSError): pass atexit.register(cleanup_connection_file) logger = logging.Logger("IPython") logger.addHandler(logging.NullHandler()) session = Session(username=u'kernel') context = zmq.Context.instance() ip = socket.gethostbyname(socket.gethostname()) transport = "tcp" addr = "%s://%s" % (transport, ip) shell_socket = context.socket(zmq.ROUTER) shell_port = shell_socket.bind_to_random_port(addr) iopub_socket = context.socket(zmq.PUB) iopub_port = iopub_socket.bind_to_random_port(addr) control_socket = context.socket(zmq.ROUTER) control_port = control_socket.bind_to_random_port(addr) hb_ctx = zmq.Context() heartbeat = Heartbeat(hb_ctx, (transport, ip, 0)) hb_port = heartbeat.port
def initIPythonKernel():
    """Start an embedded IPython ZMQ kernel on a daemon thread.

    You can remotely connect to this kernel. It binds shell/iopub/control
    sockets and a heartbeat on this host's IP. It writes a connection file
    ``ipython-kernel-<ip>-<pid>.json`` in the current directory, which is
    removed at interpreter exit. Import or setup failures are printed and
    the function returns without starting the kernel.
    """
    try:
        import IPython.kernel.zmq.ipkernel
        from IPython.kernel.zmq.ipkernel import Kernel
        from IPython.kernel.zmq.heartbeat import Heartbeat
        from IPython.kernel.zmq.session import Session
        from IPython.kernel import write_connection_file
        import zmq
        from zmq.eventloop import ioloop
        from zmq.eventloop.zmqstream import ZMQStream
        # Overwrite ipkernel's `signal` with a no-op. Presumably this is
        # because the kernel runs on a non-main thread, where
        # signal.signal() is not allowed.
        IPython.kernel.zmq.ipkernel.signal = lambda sig, f: None
    except ImportError as e:
        print("IPython import error, cannot start IPython kernel. %s" % e)
        return
    import atexit
    import socket
    import logging
    import threading

    # Do in mainthread to avoid history sqlite DB errors at exit.
    # https://github.com/ipython/ipython/issues/680
    # (current_thread replaces the deprecated currentThread alias.)
    assert isinstance(threading.current_thread(), threading._MainThread)
    try:
        ip = socket.gethostbyname(socket.gethostname())
        connection_file = "ipython-kernel-%s-%s.json" % (ip, os.getpid())

        def cleanup_connection_file():
            try:
                os.remove(connection_file)
            except (IOError, OSError):
                pass
        atexit.register(cleanup_connection_file)

        logger = logging.Logger("IPython")
        logger.addHandler(logging.NullHandler())
        session = Session(username=u'kernel')

        context = zmq.Context.instance()
        transport = "tcp"
        addr = "%s://%s" % (transport, ip)
        shell_socket = context.socket(zmq.ROUTER)
        shell_port = shell_socket.bind_to_random_port(addr)
        iopub_socket = context.socket(zmq.PUB)
        iopub_port = iopub_socket.bind_to_random_port(addr)
        control_socket = context.socket(zmq.ROUTER)
        control_port = control_socket.bind_to_random_port(addr)

        # The heartbeat gets its own zmq context; port 0 lets it pick a port.
        hb_ctx = zmq.Context()
        heartbeat = Heartbeat(hb_ctx, (transport, ip, 0))
        hb_port = heartbeat.port
        heartbeat.start()

        shell_stream = ZMQStream(shell_socket)
        control_stream = ZMQStream(control_socket)

        kernel = Kernel(session=session,
                        shell_streams=[shell_stream, control_stream],
                        iopub_socket=iopub_socket,
                        log=logger)

        write_connection_file(connection_file,
                              shell_port=shell_port, iopub_port=iopub_port, control_port=control_port,
                              hb_port=hb_port, ip=ip)

        # To connect another client to this IPython kernel, use:
        #   ipython console --existing <connection_file>
    except Exception as e:
        # Best-effort: an embedded kernel must never take down the host app.
        print("Exception while initializing IPython ZMQ kernel. %s" % e)
        return

    def ipython_thread():
        kernel.start()
        try:
            ioloop.IOLoop.instance().start()
        except KeyboardInterrupt:
            pass

    thread = threading.Thread(target=ipython_thread, name="IPython kernel")
    thread.daemon = True
    thread.start()
def open(self, kernel_id):
    """Bind this websocket to ``kernel_id`` and prepare its message session."""
    self.kernel_id = cast_unicode(kernel_id, 'ascii')
    self.session = Session(config=self.config)
    # Route the first incoming message through on_first_message, keeping the
    # regular handler in save_on_message.
    self.save_on_message, self.on_message = self.on_message, self.on_first_message
def test_serialize_binary():
    """serialize_binary_message returns bytes for a message with memoryview buffers."""
    msg = Session().msg('data_pub', content={'a': 'b'})
    msg['buffers'] = [memoryview(os.urandom(3)) for _ in range(3)]
    nt.assert_is_instance(serialize_binary_message(msg), bytes)
def initialize(self):
    """Log the incoming websocket path and attach a Session built from this handler's config."""
    path = self.request.path
    self.log.debug("Initializing websocket connection %s", path)
    self.session = Session(config=self.config)
def test_serialize_binary():
    """serialize_binary_message returns bytes for a message with raw bytes buffers."""
    session = Session()
    message = session.msg('data_pub', content={'a': 'b'})
    payloads = []
    for _ in range(3):
        payloads.append(os.urandom(3))
    message['buffers'] = payloads
    serialized = serialize_binary_message(message)
    nt.assert_is_instance(serialized, bytes)
def init_session(self):
    """create our session object

    ``default_secure`` is applied to ``self.config`` first, so the Session
    (parented to this object, username 'kernel') sees the adjusted config.
    """
    # NOTE(review): default_secure presumably fills in security defaults
    # (e.g. a signing key) on the config — confirm in its definition.
    default_secure(self.config)
    session = Session(parent=self, username=u'kernel')
    self.session = session
class TaskDBTest:
    """Backend-independent tests for a task-record database (mixin).

    Subclasses implement :meth:`create_db` and must also provide the
    ``unittest.TestCase`` assertion methods (``assertEqual``, ``assertIn``,
    ``assertRaises``, ...) that the tests below call.
    """

    def setUp(self):
        self.session = Session()
        self.db = self.create_db()
        self.load_records(16)

    def create_db(self):
        """Return the database backend under test; subclasses must override."""
        raise NotImplementedError

    def load_records(self, n=1, buffer_size=100):
        """load n records for testing

        Each record wraps an ``apply_request`` message carrying one random
        buffer of ``buffer_size`` bytes. Returns the new msg_ids in
        insertion order.
        """
        # sleep 1/10 s, to ensure timestamp is different to previous calls
        time.sleep(0.1)
        msg_ids = []
        for _ in range(n):
            msg = self.session.msg('apply_request', content=dict(a=5))
            msg['buffers'] = [os.urandom(buffer_size)]
            rec = init_record(msg)
            msg_id = msg['header']['msg_id']
            msg_ids.append(msg_id)
            self.db.add_record(msg_id, rec)
        return msg_ids

    def test_add_record(self):
        before = self.db.get_history()
        self.load_records(5)
        after = self.db.get_history()
        self.assertEqual(len(after), len(before) + 5)
        self.assertEqual(after[:-5], before)

    def test_drop_record(self):
        msg_id = self.load_records()[-1]
        # the record must be retrievable (no KeyError) before it is dropped
        self.db.get_record(msg_id)
        self.db.drop_record(msg_id)
        self.assertRaises(KeyError, self.db.get_record, msg_id)

    def _round_to_millisecond(self, dt):
        """necessary because mongodb rounds microseconds

        Truncates ``dt`` to whole milliseconds by dropping the
        sub-millisecond part of its microseconds.
        """
        return dt - timedelta(microseconds=dt.microsecond % 1000)

    def test_update_record(self):
        now = self._round_to_millisecond(datetime.now())
        msg_id = self.db.get_history()[-1]
        rec1 = self.db.get_record(msg_id)
        data = {'stdout': 'hello there', 'completed': now}
        self.db.update_record(msg_id, data)
        rec2 = self.db.get_record(msg_id)
        self.assertEqual(rec2['stdout'], 'hello there')
        self.assertEqual(rec2['completed'], now)
        rec1.update(data)
        self.assertEqual(rec1, rec2)

    # def test_update_record_bad(self):
    #     """test updating nonexistant records"""
    #     msg_id = str(uuid.uuid4())
    #     data = {'stdout': 'hello there'}
    #     self.assertRaises(KeyError, self.db.update_record, msg_id, data)

    def test_find_records_dt(self):
        """test finding records by date"""
        hist = self.db.get_history()
        middle = self.db.get_record(hist[len(hist) // 2])
        tic = middle['submitted']
        before = self.db.find_records({'submitted': {'$lt': tic}})
        after = self.db.find_records({'submitted': {'$gte': tic}})
        self.assertEqual(len(before) + len(after), len(hist))
        for b in before:
            self.assertLess(b['submitted'], tic)
        for a in after:
            self.assertGreaterEqual(a['submitted'], tic)
        same = self.db.find_records({'submitted': tic})
        for s in same:
            self.assertEqual(s['submitted'], tic)

    def test_find_records_keys(self):
        """test extracting subset of record keys"""
        found = self.db.find_records({'msg_id': {'$ne': ''}}, keys=['submitted', 'completed'])
        for rec in found:
            self.assertEqual(set(rec), {'msg_id', 'submitted', 'completed'})

    def test_find_records_msg_id(self):
        """ensure msg_id is always in found records"""
        for keys in (['submitted', 'completed'], ['submitted'], ['msg_id']):
            found = self.db.find_records({'msg_id': {'$ne': ''}}, keys=keys)
            for rec in found:
                self.assertIn('msg_id', rec)

    def test_find_records_in(self):
        """test finding records with '$in','$nin' operators"""
        hist = self.db.get_history()
        even = hist[::2]
        odd = hist[1::2]
        recs = self.db.find_records({'msg_id': {'$in': even}})
        found = [r['msg_id'] for r in recs]
        self.assertEqual(set(even), set(found))
        recs = self.db.find_records({'msg_id': {'$nin': even}})
        found = [r['msg_id'] for r in recs]
        self.assertEqual(set(odd), set(found))

    def test_get_history(self):
        msg_ids = self.db.get_history()
        latest = datetime(1984, 1, 1)
        for msg_id in msg_ids:
            rec = self.db.get_record(msg_id)
            newt = rec['submitted']
            self.assertGreaterEqual(newt, latest)
            latest = newt
        msg_id = self.load_records(1)[-1]
        self.assertEqual(self.db.get_history()[-1], msg_id)

    def test_datetime(self):
        """get/set timestamps with datetime objects"""
        msg_id = self.db.get_history()[-1]
        rec = self.db.get_record(msg_id)
        self.assertIsInstance(rec['submitted'], datetime)
        self.db.update_record(msg_id, dict(completed=datetime.now()))
        rec = self.db.get_record(msg_id)
        self.assertIsInstance(rec['completed'], datetime)

    def test_drop_matching(self):
        msg_ids = self.load_records(10)
        query = {'msg_id': {'$in': msg_ids}}
        self.db.drop_matching_records(query)
        recs = self.db.find_records(query)
        self.assertEqual(len(recs), 0)

    def test_null(self):
        """test None comparison queries"""
        self.load_records(10)
        query = {'msg_id': None}
        recs = self.db.find_records(query)
        self.assertEqual(len(recs), 0)
        query = {'msg_id': {'$ne': None}}
        recs = self.db.find_records(query)
        self.assertGreaterEqual(len(recs), 10)

    def _tamper(self, rec):
        """Mutate a query result in the ways the pop_safe tests look for."""
        rec.pop('buffers')
        rec['garbage'] = 'hello'
        rec['header']['msg_id'] = 'fubar'

    def _assert_untampered(self, rec, msg_id):
        """Assert that ``rec`` shows none of the changes made by ``_tamper``."""
        self.assertIn('buffers', rec)
        self.assertNotIn('garbage', rec)
        self.assertEqual(rec['header']['msg_id'], msg_id)

    def test_pop_safe_get(self):
        """editing query results shouldn't affect record [get]"""
        msg_id = self.db.get_history()[-1]
        self._tamper(self.db.get_record(msg_id))
        self._assert_untampered(self.db.get_record(msg_id), msg_id)

    def test_pop_safe_find(self):
        """editing query results shouldn't affect record [find]"""
        msg_id = self.db.get_history()[-1]
        self._tamper(self.db.find_records({'msg_id': msg_id})[0])
        self._assert_untampered(self.db.find_records({'msg_id': msg_id})[0], msg_id)

    def test_pop_safe_find_keys(self):
        """editing query results shouldn't affect record [find+keys]"""
        msg_id = self.db.get_history()[-1]
        rec = self.db.find_records({'msg_id': msg_id}, keys=['buffers', 'header'])[0]
        self._tamper(rec)
        self._assert_untampered(self.db.find_records({'msg_id': msg_id})[0], msg_id)
def initialize(self):
    """Attach a Session configured from this handler's config."""
    cfg = self.config
    self.session = Session(config=cfg)