def test_default_version():
    """adapt() should stamp V4toV5.version onto a header missing 'version'."""
    s = Session()
    msg = s.msg("msg_type")
    msg['header'].pop('version')
    original = copy.deepcopy(msg)
    adapted = adapt(original)
    # nose assertion replaced with a plain assert (nose is unmaintained;
    # matches the style of the sibling tests in this file)
    assert adapted['header']['version'] == V4toV5.version
def test_default_version():
    """A header stripped of 'version' gets V4toV5.version after adapt()."""
    session = Session()
    message = session.msg("msg_type")
    # simulate an old message with no version field
    message["header"].pop("version")
    snapshot = copy.deepcopy(message)
    result = adapt(snapshot)
    assert result["header"]["version"] == V4toV5.version
def test_deserialize_binary():
    """Round-trip a message carrying binary buffers through (de)serialization."""
    session = Session()
    msg = session.msg('data_pub', content={'a': 'b'})
    # attach three small opaque binary buffers
    msg['buffers'] = [memoryview(os.urandom(2)) for _ in range(3)]
    wire_bytes = serialize_binary_message(msg)
    restored = deserialize_binary_message(wire_bytes)
    assert restored == msg
def test_deserialize_binary():
    """Round-trip a message with binary buffers through (de)serialization."""
    s = Session()
    msg = s.msg('data_pub', content={'a': 'b'})
    msg['buffers'] = [memoryview(os.urandom(2)) for i in range(3)]
    bmsg = serialize_binary_message(msg)
    msg2 = deserialize_binary_message(bmsg)
    # nose assertion replaced with a plain assert (nose is unmaintained;
    # matches the other test variants in this file)
    assert msg2 == msg
def test_deserialize_binary():
    """serialize + deserialize of a buffered message is the identity."""
    session = Session()
    original = session.msg("data_pub", content={"a": "b"})
    buffers = []
    for _ in range(3):
        buffers.append(memoryview(os.urandom(2)))
    original["buffers"] = buffers
    packed = serialize_binary_message(original)
    unpacked = deserialize_binary_message(packed)
    assert unpacked == original
def __init__(self, *args, **kwargs):
    """Set up the zmq socket and signed Session used to talk to the kernel.

    Relies on module-level ``port`` and ``session_key`` being defined.
    """
    super(WidgetHandler, self).__init__(*args, **kwargs)
    self.context = zmq.Context()
    # zmq.XREQ is a long-deprecated alias for zmq.DEALER; use the modern name.
    self.socket = self.context.socket(zmq.DEALER)
    # don't block close() forever if the peer is gone (milliseconds)
    self.socket.linger = 1000
    self.socket.connect("tcp://127.0.0.1:%s" % port)
    self.session = Session(key=session_key.encode())
def main(connection_file):
    """watch iopub channel, and print messages"""
    ctx = zmq.Context.instance()
    with open(connection_file) as f:
        cfg = json.loads(f.read())
    # reg_url = cfg['interface']
    # NOTE(review): the registration interface is hard-coded here, ignoring
    # the connection file's 'interface' entry — confirm this is intentional.
    reg_url = 'tcp://140.117.168.49'
    iopub_port = cfg['iopub']
    iopub_url = "{}:{}".format(reg_url, iopub_port)
    print("iopub_url:", iopub_url)
    session = Session(key=cfg['key'].encode('ascii'))
    sub = ctx.socket(zmq.SUB)
    # This will subscribe to all messages:
    sub.SUBSCRIBE = b''
    # replace b'' with b'engine.1.stdout' to subscribe only to engine 1's stdout
    # 0MQ subscriptions are simple 'foo*' matches, so 'engine.1.' subscribes
    # to everything from engine 1, but there is no way to subscribe to
    # just stdout from everyone.
    # multiple calls to subscribe will add subscriptions, e.g. to subscribe to
    # engine 1's stderr and engine 2's stdout:
    # sub.SUBSCRIBE = b'engine.1.stderr'
    # sub.SUBSCRIBE = b'engine.2.stdout'
    sub.connect(iopub_url)
    while True:
        try:
            idents, msg = session.recv(sub, mode=0)
        except KeyboardInterrupt:
            return
        # ident always length 1 here
        topic = idents[0].decode('utf8', 'replace')
        if msg['msg_type'] == 'stream':
            # stdout/stderr
            # stream names are in msg['content']['name'], if you want to handle
            # them differently
            print("{}: {}".format(topic, msg['content']['text']))
        elif msg['msg_type'] == 'error':
            # Python traceback: print indented, one frame per line
            c = msg['content']
            print(topic + ':')
            for line in c['traceback']:
                # indent lines
                print(' ' + line)
        # a second, byte-identical `elif msg['msg_type'] == 'error'` branch
        # was removed here — it could never be reached.
def __init__(self, kernel_id, signature_key, executing_kernel_client_settings):
    """Wire up the zmq context, signed session, rabbit clients and forwarders."""
    super(ExecutingKernelClient, self).__init__()
    self.client_settings = executing_kernel_client_settings
    self.kernel_id = kernel_id
    self.context = zmq.Context()
    self.session = Session(key=signature_key)
    self.subscriber = {}
    rabbit_clients = self._init_rabbit_clients()
    self._rabbit_sender_client, self._rabbit_listener = rabbit_clients
    self._socket_forwarders = self._init_socket_forwarders()
def main(connection_file):
    """watch iopub channel, and print messages"""
    ctx = zmq.Context.instance()
    with open(connection_file) as f:
        cfg = json.loads(f.read())
    reg_url = cfg['interface']
    iopub_port = cfg['iopub']
    iopub_url = "%s:%s" % (reg_url, iopub_port)
    session = Session(key=cfg['key'].encode('ascii'))
    sub = ctx.socket(zmq.SUB)
    # This will subscribe to all messages:
    sub.SUBSCRIBE = b''
    # replace b'' with b'engine.1.stdout' to subscribe only to engine 1's stdout
    # 0MQ subscriptions are simple 'foo*' matches, so 'engine.1.' subscribes
    # to everything from engine 1, but there is no way to subscribe to
    # just stdout from everyone.
    # multiple calls to subscribe will add subscriptions, e.g. to subscribe to
    # engine 1's stderr and engine 2's stdout:
    # sub.SUBSCRIBE = b'engine.1.stderr'
    # sub.SUBSCRIBE = b'engine.2.stdout'
    sub.connect(iopub_url)
    while True:
        try:
            idents, msg = session.recv(sub, mode=0)
        except KeyboardInterrupt:
            return
        # ident always length 1 here
        topic = idents[0].decode('utf8', 'replace')
        if msg['msg_type'] == 'stream':
            # stdout/stderr
            # stream names are in msg['content']['name'], if you want to handle
            # them differently
            print("%s: %s" % (topic, msg['content']['text']))
        elif msg['msg_type'] == 'error':
            # Python traceback: print indented, one frame per line
            c = msg['content']
            print(topic + ':')
            for line in c['traceback']:
                # indent lines
                print(' ' + line)
        # a second, byte-identical `elif msg['msg_type'] == 'error'` branch
        # was removed here — it could never be reached.
def load_connection_file(self):
    """load config from a JSON connector file,
    at a *lower* priority than command-line/config files.

    Same content can be specified in $IPP_CONNECTION_INFO env
    """
    config = self.config
    # env var takes precedence over the on-disk connection file as a source
    if self.connection_info_env:
        self.log.info("Loading connection info from $IPP_CONNECTION_INFO")
        d = json.loads(self.connection_info_env)
    else:
        self.log.info("Loading connection file %r", self.url_file)
        with open(self.url_file) as f:
            d = json.load(f)
    # allow hand-override of location for disambiguation
    # and ssh-server
    if 'IPEngine.location' not in self.cli_config:
        self.location = d['location']
    if 'ssh' in d and not self.sshserver:
        self.sshserver = d.get("ssh")
    # rewrite the interface IP relative to our (possibly overridden) location
    proto, ip = d['interface'].split('://')
    ip = disambiguate_ip_address(ip, self.location)
    d['interface'] = f'{proto}://{ip}'
    if d.get('curve_serverkey'):
        # connection file takes precedence over env, if present and defined
        self.curve_serverkey = d['curve_serverkey'].encode('ascii')
    if self.curve_serverkey:
        self.log.info("Using CurveZMQ security")
        self._ensure_curve_keypair()
    else:
        self.log.warning("Not using CurveZMQ security")
    # DO NOT allow override of basic URLs, serialization, or key
    # JSON file takes top priority there
    if d.get('key') or 'key' not in config.Session:
        config.Session.key = d.get('key', '').encode('utf8')
    config.Session.signature_scheme = d['signature_scheme']
    self.registration_url = f"{d['interface']}:{d['registration']}"
    config.Session.packer = d['pack']
    config.Session.unpacker = d['unpack']
    # rebuild the session so it picks up the config set above
    self.session = Session(parent=self)
    self.log.debug("Config changed:")
    self.log.debug("%r", config)
    self.connection_info = d
def test_load_connection_file_session_with_kn():
    """load_connection_file() should copy key and signature_scheme from the
    connection file into the app's session."""
    # the original created an extra `Session()` here that was immediately
    # shadowed below — removed as dead code
    app = DummyConsoleApp(session=Session())
    app.initialize(argv=[])
    session = app.session
    with TemporaryDirectory() as d:
        cf = os.path.join(d, 'kernel.json')
        connect.write_connection_file(cf, **sample_info_kn)
        app.connection_file = cf
        app.load_connection_file()
        assert session.key == sample_info_kn['key']
        assert session.signature_scheme == sample_info_kn['signature_scheme']
def test_io_api():
    """Test that wrapped stdout has the same API as a normal TextIO object"""
    session = Session()
    ctx = zmq.Context()
    pub = ctx.socket(zmq.PUB)
    thread = IOPubThread(pub)
    thread.start()
    stream = OutStream(session, thread, 'stdout')
    # cleanup unused zmq objects before we start testing
    thread.stop()
    thread.close()
    ctx.term()
    assert stream.errors is None
    assert not stream.isatty()
    # every read / random-access operation must be rejected
    unsupported = (
        stream.detach,
        lambda: next(stream),
        stream.read,
        stream.readline,
        stream.seek,
        stream.tell,
    )
    for operation in unsupported:
        with nt.assert_raises(io.UnsupportedOperation):
            operation()
def start_watching_activity(self, kernel_id):
    """Start watching IOPub messages on a kernel for activity.

    - update last_activity on every message
    - record execution_state from status messages
    """
    kernel = self._kernels[kernel_id]
    # add busy/activity markers:
    kernel.execution_state = 'starting'
    kernel.last_activity = utcnow()
    kernel._activity_stream = kernel.connect_iopub()
    # separate Session built from the kernel's config/key — presumably so
    # deserializing here doesn't disturb the kernel's own session state;
    # TODO confirm
    session = Session(
        config=kernel.session.config,
        key=kernel.session.key,
    )

    def record_activity(msg_list):
        """Record an IOPub message arriving from a kernel"""
        # any message at all counts as activity
        kernel.last_activity = utcnow()
        idents, fed_msg_list = session.feed_identities(msg_list)
        msg = session.deserialize(fed_msg_list)
        msg_type = msg['header']['msg_type']
        self.log.debug("activity on %s: %s", kernel_id, msg_type)
        if msg_type == 'status':
            # mirror the kernel's reported execution state (busy/idle/...)
            kernel.execution_state = msg['content']['execution_state']

    kernel._activity_stream.on_recv(record_activity)
def test_io_isatty():
    """An OutStream constructed with isatty=True must report being a tty."""
    session = Session()
    context = zmq.Context()
    publisher = context.socket(zmq.PUB)
    iopub_thread = IOPubThread(publisher)
    iopub_thread.start()
    tty_stream = OutStream(session, iopub_thread, 'stdout', isatty=True)
    assert tty_stream.isatty()
class WidgetHandler(FileSystemEventHandler):
    # Watches filesystem events and forwards each one to a kernel comm
    # over a zmq DEALER socket, framed as Jupyter 'comm_msg' messages.
    # Relies on module-level `port`, `session_key`, `session`, and `model_id`.

    def __init__(self, *args, **kwargs):
        super(WidgetHandler, self).__init__(*args, **kwargs)
        self.context = zmq.Context()
        # NOTE(review): zmq.XREQ is a deprecated alias for zmq.DEALER —
        # consider updating
        self.socket = self.context.socket(zmq.XREQ)
        # don't block close() forever if the peer is gone (milliseconds)
        self.socket.linger = 1000
        self.socket.connect("tcp://127.0.0.1:%s" % port)
        self.session = Session(key=session_key.encode())

    def on_any_event(self, event):
        # called by watchdog for every filesystem event
        msg = self.msg(event)
        self.session.send(self.socket, msg)

    def msg(self, event):
        """Build a Jupyter 'comm_msg'-shaped dict describing *event*."""
        return {
            "buffers": [],
            "channel": "shell",
            "metadata": {},
            "parent_header": {},
            "header": {
                "msg_id": str(uuid4()),
                "username": "******",
                "session": session,
                "msg_type": "comm_msg",
                "version": "5.0"
            },
            "content": {
                "comm_id": model_id,
                "data": {
                    "method": "custom",
                    "content": {
                        "event": event.event_type,
                        "is_directory": event.is_directory,
                        "src_path": event.src_path,
                        # move events carry dest_path; others don't
                        "dest_path": getattr(event, "dest_path", None)
                    }
                },
            },
        }
def _create_session(self):
    """Create the signed Session used for kernel messaging.

    Imports lazily so jupyter_client is only required when a session is
    actually created.
    """
    from jupyter_client.session import Session
    try:
        from jupyter_client.session import new_id_bytes
    except ImportError:
        def new_id_bytes():
            """Fallback for jupyter_client versions without new_id_bytes."""
            import uuid
            # Session keys must be bytes; the previous fallback returned a
            # uuid.UUID object, which is not a valid HMAC key.
            return uuid.uuid4().hex.encode('ascii')
    self._session = Session(username=u'kernel', key=new_id_bytes())
def create_shell(username, session_id, key):
    """Instantiates a CapturingSocket and SwiftShell and hooks them up.

    After you call this, the returned CapturingSocket should capture all
    IPython display messages.
    """
    capturing_socket = CapturingSocket()
    messaging_session = Session(username=username, session=session_id, key=key)
    shell = SwiftShell.instance()
    # route the shell's display publisher through our capturing socket
    shell.display_pub.session = messaging_session
    shell.display_pub.pub_socket = capturing_socket
    return (capturing_socket, shell)
def _setup_session(self, reply, comp_id):
    """
    Set up the kernel information contained in the untrusted reply
    message `reply` from computer `comp_id`.
    """
    content = reply["content"]
    kernel_id = content["kernel_id"]
    connection = content["connection"]
    # register the kernel and link it to its computer
    self._kernels[kernel_id] = {
        "comp_id": comp_id,
        "connection": connection,
        "executing": 0,
    }
    self._comps[comp_id]["kernels"][kernel_id] = None
    # per-kernel session signed with the kernel's own key
    self._sessions[kernel_id] = Session(key=connection["key"])
def __init__(
    self,
    *,
    pid: int,
    engine_id: int,
    control_url: str,
    registration_url: str,
    identity: bytes,
    curve_serverkey: bytes,
    curve_publickey: bytes,
    curve_secretkey: bytes,
    config: Config,
    pipe,
    log_level: int = logging.INFO,
):
    """Store engine/connection parameters and set up session, logger,
    and control-message handlers for the watched engine process."""
    self.pid = pid
    self.engine_id = engine_id
    # handle on the engine process we are monitoring
    self.parent_process = psutil.Process(self.pid)
    self.control_url = control_url
    self.registration_url = registration_url
    self.identity = identity
    # CurveZMQ key material for securing connections
    self.curve_serverkey = curve_serverkey
    self.curve_publickey = curve_publickey
    self.curve_secretkey = curve_secretkey
    self.config = config
    self.pipe = pipe
    self.session = Session(config=self.config)
    # per-instance logger tagged with the engine id; don't bubble records
    # up to the root logger
    self.log = local_logger(f"{self.__class__.__name__}.{engine_id}", log_level)
    self.log.propagate = False
    # dispatch table for control-channel requests
    self.control_handlers = {
        "signal_request": self.signal_request,
    }
    self._finish_called = False
def _setup_session(self, reply, comp_id, timeout=None):
    """
    Set up the kernel information contained in the untrusted reply
    message `reply` from computer `comp_id`.
    """
    content = reply["content"]
    kernel_id = content["kernel_id"]
    connection = content["connection"]
    if timeout is None:
        timeout = self.max_kernel_timeout
    self._kernels[kernel_id] = {
        "comp_id": comp_id,
        "connection": connection,
        "executing": 0,  # number of active execute_requests
        "deadline": time.time() + timeout,
        "timeout": timeout,
    }
    self._comps[comp_id]["kernels"][kernel_id] = None
    # per-kernel session signed with the kernel's own key
    self._sessions[kernel_id] = Session(key=connection["key"])
def create_on_reply(self, kernel_id):
    """Creates an anonymous function to handle reply messages from the kernel.

    Parameters
    ----------
    kernel_id
        Kernel to listen to

    Returns
    -------
    function
        Callback function taking a kernel ID and 0mq message list
    """
    kernel = self.kernel_clients[kernel_id]
    # dedicated session sharing the kernel's config and signing key
    reply_session = Session(
        config=kernel.session.config,
        key=kernel.session.key,
    )

    def handle_reply(msg_list):
        return self._on_reply(kernel_id, reply_session, msg_list)

    return handle_reply
def start_watching_activity(self, kernel_id):
    """Start watching IOPub messages on a kernel for activity.

    - update last_activity on every message
    - record execution_state from status messages
    """
    kernel = self._kernels[kernel_id]
    # add busy/activity markers:
    kernel.execution_state = "starting"
    kernel.reason = ""
    kernel.last_activity = utcnow()
    kernel._activity_stream = kernel.connect_iopub()
    # separate Session built from the kernel's config/key — presumably so
    # deserializing here doesn't disturb the kernel's own session state;
    # TODO confirm
    session = Session(
        config=kernel.session.config,
        key=kernel.session.key,
    )

    def record_activity(msg_list):
        """Record an IOPub message arriving from a kernel"""
        self.last_kernel_activity = kernel.last_activity = utcnow()
        idents, fed_msg_list = session.feed_identities(msg_list)
        # skip content deserialization unless we actually need it (below)
        msg = session.deserialize(fed_msg_list, content=False)
        msg_type = msg["header"]["msg_type"]
        if msg_type == "status":
            # only status messages need their content parsed
            msg = session.deserialize(fed_msg_list)
            kernel.execution_state = msg["content"]["execution_state"]
            self.log.debug(
                "activity on %s: %s (%s)",
                kernel_id,
                msg_type,
                kernel.execution_state,
            )
        else:
            self.log.debug("activity on %s: %s", kernel_id, msg_type)

    kernel._activity_stream.on_recv(record_activity)
def test_serialize_binary():
    """serialize_binary_message should produce bytes for a buffered message."""
    s = Session()
    msg = s.msg('data_pub', content={'a': 'b'})
    msg['buffers'] = [memoryview(os.urandom(3)) for i in range(3)]
    bmsg = serialize_binary_message(msg)
    # nose assertion replaced with a plain assert (nose is unmaintained)
    assert isinstance(bmsg, bytes)
def initialize(self):
    """Prepare the websocket connection: config-backed session plus a
    client pointed at the configured gateway URL."""
    self.log.debug("Initializing websocket connection %s", self.request.path)
    self.session = Session(config=self.config)
    gateway_url = GatewayClient.instance().url
    self.gateway = GatewayWebSocketClient(gateway_url=gateway_url)
def setUp(self):
    """Create a fresh Session for each test."""
    self.session = Session()
class TaskDBTest:
    """Shared test suite for task database backends.

    Subclasses provide the backend by overriding ``create_db``; each test
    runs against a DB pre-populated with 16 records.
    """

    def setUp(self):
        self.session = Session()
        self.db = self.create_db()
        self.load_records(16)

    def create_db(self):
        # backend-specific; must be provided by the concrete subclass
        raise NotImplementedError

    def load_records(self, n=1, buffer_size=100):
        """load n records for testing"""
        #sleep 1/10 s, to ensure timestamp is different to previous calls
        time.sleep(0.1)
        msg_ids = []
        for i in range(n):
            msg = self.session.msg('apply_request', content=dict(a=5))
            msg['buffers'] = [os.urandom(buffer_size)]
            rec = init_record(msg)
            msg_id = msg['header']['msg_id']
            msg_ids.append(msg_id)
            self.db.add_record(msg_id, rec)
        return msg_ids

    def test_add_record(self):
        before = self.db.get_history()
        self.load_records(5)
        after = self.db.get_history()
        self.assertEqual(len(after), len(before)+5)
        # new records are appended; history order of old records is preserved
        self.assertEqual(after[:-5], before)

    def test_drop_record(self):
        msg_id = self.load_records()[-1]
        rec = self.db.get_record(msg_id)
        self.db.drop_record(msg_id)
        self.assertRaises(KeyError, self.db.get_record, msg_id)

    def _round_to_millisecond(self, dt):
        """necessary because mongodb rounds microseconds"""
        micro = dt.microsecond
        extra = int(str(micro)[-3:])
        return dt - timedelta(microseconds=extra)

    def test_update_record(self):
        now = self._round_to_millisecond(datetime.now())
        #
        msg_id = self.db.get_history()[-1]
        rec1 = self.db.get_record(msg_id)
        data = {'stdout': 'hello there', 'completed': now}
        self.db.update_record(msg_id, data)
        rec2 = self.db.get_record(msg_id)
        self.assertEqual(rec2['stdout'], 'hello there')
        self.assertEqual(rec2['completed'], now)
        # updating rec1 in-memory should now make it equal to the stored copy
        rec1.update(data)
        self.assertEqual(rec1, rec2)

    # def test_update_record_bad(self):
    #     """test updating nonexistant records"""
    #     msg_id = str(uuid.uuid4())
    #     data = {'stdout': 'hello there'}
    #     self.assertRaises(KeyError, self.db.update_record, msg_id, data)

    def test_find_records_dt(self):
        """test finding records by date"""
        hist = self.db.get_history()
        middle = self.db.get_record(hist[len(hist)//2])
        tic = middle['submitted']
        before = self.db.find_records({'submitted': {'$lt': tic}})
        after = self.db.find_records({'submitted': {'$gte': tic}})
        # before/after partition the full history
        self.assertEqual(len(before)+len(after), len(hist))
        for b in before:
            self.assertTrue(b['submitted'] < tic)
        for a in after:
            self.assertTrue(a['submitted'] >= tic)
        same = self.db.find_records({'submitted': tic})
        for s in same:
            self.assertTrue(s['submitted'] == tic)

    def test_find_records_keys(self):
        """test extracting subset of record keys"""
        found = self.db.find_records({'msg_id': {'$ne': ''}}, keys=['submitted', 'completed'])
        for rec in found:
            self.assertEqual(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))

    def test_find_records_msg_id(self):
        """ensure msg_id is always in found records"""
        found = self.db.find_records({'msg_id': {'$ne': ''}}, keys=['submitted', 'completed'])
        for rec in found:
            self.assertTrue('msg_id' in rec.keys())
        found = self.db.find_records({'msg_id': {'$ne': ''}}, keys=['submitted'])
        for rec in found:
            self.assertTrue('msg_id' in rec.keys())
        found = self.db.find_records({'msg_id': {'$ne': ''}}, keys=['msg_id'])
        for rec in found:
            self.assertTrue('msg_id' in rec.keys())

    def test_find_records_in(self):
        """test finding records with '$in','$nin' operators"""
        hist = self.db.get_history()
        even = hist[::2]
        odd = hist[1::2]
        recs = self.db.find_records({'msg_id': {'$in': even}})
        found = [r['msg_id'] for r in recs]
        self.assertEqual(set(even), set(found))
        recs = self.db.find_records({'msg_id': {'$nin': even}})
        found = [r['msg_id'] for r in recs]
        self.assertEqual(set(odd), set(found))

    def test_get_history(self):
        # history must be sorted by submission time, oldest first
        msg_ids = self.db.get_history()
        latest = datetime(1984, 1, 1)
        for msg_id in msg_ids:
            rec = self.db.get_record(msg_id)
            newt = rec['submitted']
            self.assertTrue(newt >= latest)
            latest = newt
        msg_id = self.load_records(1)[-1]
        self.assertEqual(self.db.get_history()[-1], msg_id)

    def test_datetime(self):
        """get/set timestamps with datetime objects"""
        msg_id = self.db.get_history()[-1]
        rec = self.db.get_record(msg_id)
        self.assertTrue(isinstance(rec['submitted'], datetime))
        self.db.update_record(msg_id, dict(completed=datetime.now()))
        rec = self.db.get_record(msg_id)
        self.assertTrue(isinstance(rec['completed'], datetime))

    def test_drop_matching(self):
        msg_ids = self.load_records(10)
        query = {'msg_id': {'$in': msg_ids}}
        self.db.drop_matching_records(query)
        recs = self.db.find_records(query)
        self.assertEqual(len(recs), 0)

    def test_null(self):
        """test None comparison queries"""
        msg_ids = self.load_records(10)
        query = {'msg_id': None}
        recs = self.db.find_records(query)
        self.assertEqual(len(recs), 0)
        query = {'msg_id': {'$ne': None}}
        recs = self.db.find_records(query)
        self.assertTrue(len(recs) >= 10)

    def test_pop_safe_get(self):
        """editing query results shouldn't affect record [get]"""
        msg_id = self.db.get_history()[-1]
        rec = self.db.get_record(msg_id)
        rec.pop('buffers')
        rec['garbage'] = 'hello'
        rec['header']['msg_id'] = 'fubar'
        rec2 = self.db.get_record(msg_id)
        self.assertTrue('buffers' in rec2)
        self.assertFalse('garbage' in rec2)
        self.assertEqual(rec2['header']['msg_id'], msg_id)

    def test_pop_safe_find(self):
        """editing query results shouldn't affect record [find]"""
        msg_id = self.db.get_history()[-1]
        rec = self.db.find_records({'msg_id': msg_id})[0]
        rec.pop('buffers')
        rec['garbage'] = 'hello'
        rec['header']['msg_id'] = 'fubar'
        rec2 = self.db.find_records({'msg_id': msg_id})[0]
        self.assertTrue('buffers' in rec2)
        self.assertFalse('garbage' in rec2)
        self.assertEqual(rec2['header']['msg_id'], msg_id)

    def test_pop_safe_find_keys(self):
        """editing query results shouldn't affect record [find+keys]"""
        msg_id = self.db.get_history()[-1]
        rec = self.db.find_records({'msg_id': msg_id}, keys=['buffers', 'header'])[0]
        rec.pop('buffers')
        rec['garbage'] = 'hello'
        rec['header']['msg_id'] = 'fubar'
        rec2 = self.db.find_records({'msg_id': msg_id})[0]
        self.assertTrue('buffers' in rec2)
        self.assertFalse('garbage' in rec2)
        self.assertEqual(rec2['header']['msg_id'], msg_id)
def initialize(self):
    """Prepare the websocket connection: config-backed session plus a
    kernel-gateway websocket client."""
    self.log.debug("Initializing websocket connection %s", self.request.path)
    self.session = Session(config=self.config)
    # TODO: make kernel client class configurable
    self.gateway = KernelGatewayWSClient()
def setUp(self):
    """Create a fresh session and backend DB, pre-populated with 16 records."""
    self.session = Session()
    self.db = self.create_db()
    self.load_records(16)
def _session_default(self):
    """Default Session for in-process messaging.

    The empty key disables HMAC signing entirely.
    """
    # don't sign in-process messages
    return Session(key=b'', parent=self)
def _session_default(self):
    """Build the default Session (jupyter_client imported lazily so it is
    only required when a session is actually needed)."""
    from jupyter_client.session import Session
    return Session(parent=self)
def _default_session(self):
    """Default Session keyed for in-process use (jupyter_client imported
    lazily to avoid a hard dependency at module import time)."""
    from jupyter_client.session import Session
    return Session(parent=self, key=INPROCESS_KEY)
class TaskDBTest:
    """Shared test suite for task database backends (timezone-aware variant).

    Subclasses provide the backend by overriding ``create_db``; each test
    runs against a DB pre-populated with 16 records.
    """

    def setUp(self):
        self.session = Session()
        self.db = self.create_db()
        self.load_records(16)

    def create_db(self):
        # backend-specific; must be provided by the concrete subclass
        raise NotImplementedError

    def load_records(self, n=1, buffer_size=100):
        """load n records for testing"""
        #sleep 1/10 s, to ensure timestamp is different to previous calls
        time.sleep(0.1)
        msg_ids = []
        for i in range(n):
            msg = self.session.msg('apply_request', content=dict(a=5))
            msg['buffers'] = [os.urandom(buffer_size)]
            rec = init_record(msg)
            msg_id = msg['header']['msg_id']
            msg_ids.append(msg_id)
            self.db.add_record(msg_id, rec)
        return msg_ids

    def test_add_record(self):
        before = self.db.get_history()
        self.load_records(5)
        after = self.db.get_history()
        self.assertEqual(len(after), len(before)+5)
        # new records are appended; history order of old records is preserved
        self.assertEqual(after[:-5], before)

    def test_drop_record(self):
        msg_id = self.load_records()[-1]
        rec = self.db.get_record(msg_id)
        self.db.drop_record(msg_id)
        self.assertRaises(KeyError, self.db.get_record, msg_id)

    def _round_to_millisecond(self, dt):
        """necessary because mongodb rounds microseconds"""
        micro = dt.microsecond
        extra = int(str(micro)[-3:])
        return dt - timedelta(microseconds=extra)

    def test_update_record(self):
        now = self._round_to_millisecond(util.utcnow())
        msg_id = self.db.get_history()[-1]
        rec1 = self.db.get_record(msg_id)
        data = {'stdout': 'hello there', 'completed': now}
        self.db.update_record(msg_id, data)
        rec2 = self.db.get_record(msg_id)
        self.assertEqual(rec2['stdout'], 'hello there')
        self.assertEqual(rec2['completed'], now)
        # updating rec1 in-memory should now make it equal to the stored copy
        rec1.update(data)
        self.assertEqual(rec1, rec2)

    # def test_update_record_bad(self):
    #     """test updating nonexistant records"""
    #     msg_id = str(uuid.uuid4())
    #     data = {'stdout': 'hello there'}
    #     self.assertRaises(KeyError, self.db.update_record, msg_id, data)

    def test_find_records_dt(self):
        """test finding records by date"""
        hist = self.db.get_history()
        middle = self.db.get_record(hist[len(hist)//2])
        tic = middle['submitted']
        before = self.db.find_records({'submitted': {'$lt': tic}})
        after = self.db.find_records({'submitted': {'$gte': tic}})
        # before/after partition the full history
        self.assertEqual(len(before)+len(after), len(hist))
        for b in before:
            self.assertLess(b['submitted'], tic)
        for a in after:
            self.assertGreaterEqual(a['submitted'], tic)
        same = self.db.find_records({'submitted': tic})
        for s in same:
            self.assertEqual(s['submitted'], tic)

    def test_find_records_keys(self):
        """test extracting subset of record keys"""
        found = self.db.find_records({'msg_id': {'$ne': ''}}, keys=['submitted', 'completed'])
        for rec in found:
            self.assertEqual(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))

    def test_find_records_msg_id(self):
        """ensure msg_id is always in found records"""
        found = self.db.find_records({'msg_id': {'$ne': ''}}, keys=['submitted', 'completed'])
        for rec in found:
            self.assertTrue('msg_id' in rec.keys())
        found = self.db.find_records({'msg_id': {'$ne': ''}}, keys=['submitted'])
        for rec in found:
            self.assertTrue('msg_id' in rec.keys())
        found = self.db.find_records({'msg_id': {'$ne': ''}}, keys=['msg_id'])
        for rec in found:
            self.assertTrue('msg_id' in rec.keys())

    def test_find_records_in(self):
        """test finding records with '$in','$nin' operators"""
        hist = self.db.get_history()
        even = hist[::2]
        odd = hist[1::2]
        recs = self.db.find_records({'msg_id': {'$in': even}})
        found = [r['msg_id'] for r in recs]
        self.assertEqual(set(even), set(found))
        recs = self.db.find_records({'msg_id': {'$nin': even}})
        found = [r['msg_id'] for r in recs]
        self.assertEqual(set(odd), set(found))

    def test_get_history(self):
        # history must be sorted by submission time, oldest first;
        # sentinel is tz-aware to compare against aware timestamps
        msg_ids = self.db.get_history()
        latest = datetime(1984, 1, 1).replace(tzinfo=utc)
        for msg_id in msg_ids:
            rec = self.db.get_record(msg_id)
            newt = rec['submitted']
            self.assertTrue(newt >= latest)
            latest = newt
        msg_id = self.load_records(1)[-1]
        self.assertEqual(self.db.get_history()[-1], msg_id)

    def test_datetime(self):
        """get/set timestamps with datetime objects"""
        msg_id = self.db.get_history()[-1]
        rec = self.db.get_record(msg_id)
        self.assertTrue(isinstance(rec['submitted'], datetime))
        self.db.update_record(msg_id, dict(completed=util.utcnow()))
        rec = self.db.get_record(msg_id)
        self.assertTrue(isinstance(rec['completed'], datetime))

    def test_drop_matching(self):
        msg_ids = self.load_records(10)
        query = {'msg_id': {'$in': msg_ids}}
        self.db.drop_matching_records(query)
        recs = self.db.find_records(query)
        self.assertEqual(len(recs), 0)

    def test_null(self):
        """test None comparison queries"""
        msg_ids = self.load_records(10)
        query = {'msg_id': None}
        recs = self.db.find_records(query)
        self.assertEqual(len(recs), 0)
        query = {'msg_id': {'$ne': None}}
        recs = self.db.find_records(query)
        self.assertTrue(len(recs) >= 10)

    def test_pop_safe_get(self):
        """editing query results shouldn't affect record [get]"""
        msg_id = self.db.get_history()[-1]
        rec = self.db.get_record(msg_id)
        rec.pop('buffers')
        rec['garbage'] = 'hello'
        rec['header']['msg_id'] = 'fubar'
        rec2 = self.db.get_record(msg_id)
        self.assertTrue('buffers' in rec2)
        self.assertFalse('garbage' in rec2)
        self.assertEqual(rec2['header']['msg_id'], msg_id)

    def test_pop_safe_find(self):
        """editing query results shouldn't affect record [find]"""
        msg_id = self.db.get_history()[-1]
        rec = self.db.find_records({'msg_id': msg_id})[0]
        rec.pop('buffers')
        rec['garbage'] = 'hello'
        rec['header']['msg_id'] = 'fubar'
        rec2 = self.db.find_records({'msg_id': msg_id})[0]
        self.assertTrue('buffers' in rec2)
        self.assertFalse('garbage' in rec2)
        self.assertEqual(rec2['header']['msg_id'], msg_id)

    def test_pop_safe_find_keys(self):
        """editing query results shouldn't affect record [find+keys]"""
        msg_id = self.db.get_history()[-1]
        rec = self.db.find_records({'msg_id': msg_id}, keys=['buffers', 'header'])[0]
        rec.pop('buffers')
        rec['garbage'] = 'hello'
        rec['header']['msg_id'] = 'fubar'
        rec2 = self.db.find_records({'msg_id': msg_id})[0]
        self.assertTrue('buffers' in rec2)
        self.assertFalse('garbage' in rec2)
        self.assertEqual(rec2['header']['msg_id'], msg_id)
def initialize(self):
    """Prepare the websocket connection with a config-backed Session."""
    self.log.debug("Initializing websocket connection %s", self.request.path)
    self.session = Session(config=self.config)
def _default_session(self):
    """Default Session trait value, parented to this object so it inherits
    its configuration."""
    return Session(parent=self)