def test_default_version():
    """A message whose header lacks 'version' gets V4toV5.version after adapt()."""
    session = Session()
    message = session.msg("msg_type")
    # remove the version field so the adapter must supply the default
    message['header'].pop('version')
    snapshot = copy.deepcopy(message)
    upgraded = adapt(snapshot)
    nt.assert_equal(upgraded['header']['version'], V4toV5.version)
def test_deserialize_binary():
    """deserialize_binary_message() exactly inverts serialize_binary_message()."""
    session = Session()
    message = session.msg("data_pub", content={"a": "b"})
    # attach a few tiny binary buffers to exercise the buffer path
    message["buffers"] = [memoryview(os.urandom(2)) for _ in range(3)]
    wire = serialize_binary_message(message)
    restored = deserialize_binary_message(wire)
    assert restored == message
def test_deserialize_binary():
    """Round-trip: serialize then deserialize yields the original message."""
    session = Session()
    message = session.msg('data_pub', content={'a': 'b'})
    # several small random buffers exercise multi-buffer handling
    message['buffers'] = [memoryview(os.urandom(2)) for _ in range(3)]
    wire = serialize_binary_message(message)
    restored = deserialize_binary_message(wire)
    assert restored == message
def test_default_version():
    """adapt() fills in V4toV5.version when the header has no version field."""
    session = Session()
    message = session.msg("msg_type")
    # drop the version so the adapter has to provide the default
    message["header"].pop("version")
    snapshot = copy.deepcopy(message)
    upgraded = adapt(snapshot)
    assert upgraded["header"]["version"] == V4toV5.version
def test_deserialize_binary():
    """Binary serialization round-trips a message with buffers (nose-style assert)."""
    session = Session()
    message = session.msg('data_pub', content={'a': 'b'})
    # small random payloads are enough to cover the buffer path
    message['buffers'] = [memoryview(os.urandom(2)) for _ in range(3)]
    wire = serialize_binary_message(message)
    restored = deserialize_binary_message(wire)
    nt.assert_equal(restored, message)
class TaskDBTest:
    """Shared test suite for task-record database backends.

    Subclasses must implement ``create_db()`` to return a concrete backend;
    the queries use Mongo-style operators ($lt, $gte, $in, $nin, $ne),
    which every backend is expected to support.
    """

    def setUp(self):
        # Fresh session + backend, pre-populated with 16 records.
        self.session = Session()
        self.db = self.create_db()
        self.load_records(16)

    def create_db(self):
        # Abstract: concrete backend supplied by the subclass.
        raise NotImplementedError

    def load_records(self, n=1, buffer_size=100):
        """load n records for testing

        Returns the list of msg_ids added, in insertion order.
        """
        #sleep 1/10 s, to ensure timestamp is different to previous calls
        time.sleep(0.1)
        msg_ids = []
        for i in range(n):
            msg = self.session.msg('apply_request', content=dict(a=5))
            # each record carries one opaque binary buffer
            msg['buffers'] = [os.urandom(buffer_size)]
            rec = init_record(msg)
            msg_id = msg['header']['msg_id']
            msg_ids.append(msg_id)
            self.db.add_record(msg_id, rec)
        return msg_ids

    def test_add_record(self):
        # Adding 5 records extends history by exactly 5, preserving order.
        before = self.db.get_history()
        self.load_records(5)
        after = self.db.get_history()
        self.assertEqual(len(after), len(before)+5)
        self.assertEqual(after[:-5],before)

    def test_drop_record(self):
        # Dropping a record makes subsequent lookups raise KeyError.
        msg_id = self.load_records()[-1]
        rec = self.db.get_record(msg_id)
        self.db.drop_record(msg_id)
        self.assertRaises(KeyError,self.db.get_record, msg_id)

    def _round_to_millisecond(self, dt):
        """necessary because mongodb rounds microseconds"""
        micro = dt.microsecond
        # strip the sub-millisecond part (last 3 digits of microseconds)
        extra = int(str(micro)[-3:])
        return dt - timedelta(microseconds=extra)

    def test_update_record(self):
        # An update must be visible via get_record and merge cleanly.
        now = self._round_to_millisecond(util.utcnow())
        msg_id = self.db.get_history()[-1]
        rec1 = self.db.get_record(msg_id)
        data = {'stdout': 'hello there', 'completed' : now}
        self.db.update_record(msg_id, data)
        rec2 = self.db.get_record(msg_id)
        self.assertEqual(rec2['stdout'], 'hello there')
        self.assertEqual(rec2['completed'], now)
        rec1.update(data)
        self.assertEqual(rec1, rec2)

    # def test_update_record_bad(self):
    #     """test updating nonexistant records"""
    #     msg_id = str(uuid.uuid4())
    #     data = {'stdout': 'hello there'}
    #     self.assertRaises(KeyError, self.db.update_record, msg_id, data)

    def test_find_records_dt(self):
        """test finding records by date"""
        hist = self.db.get_history()
        # pick the middle record's timestamp as the pivot
        middle = self.db.get_record(hist[len(hist)//2])
        tic = middle['submitted']
        before = self.db.find_records({'submitted' : {'$lt' : tic}})
        after = self.db.find_records({'submitted' : {'$gte' : tic}})
        # $lt and $gte partition the full history
        self.assertEqual(len(before)+len(after),len(hist))
        for b in before:
            self.assertLess(b['submitted'], tic)
        for a in after:
            self.assertGreaterEqual(a['submitted'], tic)
        same = self.db.find_records({'submitted' : tic})
        for s in same:
            self.assertEqual(s['submitted'], tic)

    def test_find_records_keys(self):
        """test extracting subset of record keys"""
        found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
        for rec in found:
            # msg_id is always included even though it wasn't requested
            self.assertEqual(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))

    def test_find_records_msg_id(self):
        """ensure msg_id is always in found records"""
        found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
        for rec in found:
            self.assertTrue('msg_id' in rec.keys())
        found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted'])
        for rec in found:
            self.assertTrue('msg_id' in rec.keys())
        found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['msg_id'])
        for rec in found:
            self.assertTrue('msg_id' in rec.keys())

    def test_find_records_in(self):
        """test finding records with '$in','$nin' operators"""
        hist = self.db.get_history()
        even = hist[::2]
        odd = hist[1::2]
        recs = self.db.find_records({ 'msg_id' : {'$in' : even}})
        found = [ r['msg_id'] for r in recs ]
        self.assertEqual(set(even), set(found))
        recs = self.db.find_records({ 'msg_id' : {'$nin' : even}})
        found = [ r['msg_id'] for r in recs ]
        self.assertEqual(set(odd), set(found))

    def test_get_history(self):
        # History must be ordered by (non-decreasing) submission time.
        msg_ids = self.db.get_history()
        # tz-aware sentinel older than any record
        latest = datetime(1984,1,1).replace(tzinfo=utc)
        for msg_id in msg_ids:
            rec = self.db.get_record(msg_id)
            newt = rec['submitted']
            self.assertTrue(newt >= latest)
            latest = newt
        # a newly added record must appear at the end of history
        msg_id = self.load_records(1)[-1]
        self.assertEqual(self.db.get_history()[-1],msg_id)

    def test_datetime(self):
        """get/set timestamps with datetime objects"""
        msg_id = self.db.get_history()[-1]
        rec = self.db.get_record(msg_id)
        self.assertTrue(isinstance(rec['submitted'], datetime))
        self.db.update_record(msg_id, dict(completed=util.utcnow()))
        rec = self.db.get_record(msg_id)
        self.assertTrue(isinstance(rec['completed'], datetime))

    def test_drop_matching(self):
        # drop_matching_records removes every record matched by the query
        msg_ids = self.load_records(10)
        query = {'msg_id' : {'$in':msg_ids}}
        self.db.drop_matching_records(query)
        recs = self.db.find_records(query)
        self.assertEqual(len(recs), 0)

    def test_null(self):
        """test None comparison queries"""
        msg_ids = self.load_records(10)
        query = {'msg_id' : None}
        recs = self.db.find_records(query)
        self.assertEqual(len(recs), 0)
        query = {'msg_id' : {'$ne' : None}}
        recs = self.db.find_records(query)
        self.assertTrue(len(recs) >= 10)

    def test_pop_safe_get(self):
        """editing query results shouldn't affect record [get]"""
        msg_id = self.db.get_history()[-1]
        rec = self.db.get_record(msg_id)
        # mutate the returned dict in several ways...
        rec.pop('buffers')
        rec['garbage'] = 'hello'
        rec['header']['msg_id'] = 'fubar'
        # ...then confirm a fresh fetch is untouched
        rec2 = self.db.get_record(msg_id)
        self.assertTrue('buffers' in rec2)
        self.assertFalse('garbage' in rec2)
        self.assertEqual(rec2['header']['msg_id'], msg_id)

    def test_pop_safe_find(self):
        """editing query results shouldn't affect record [find]"""
        msg_id = self.db.get_history()[-1]
        rec = self.db.find_records({'msg_id' : msg_id})[0]
        rec.pop('buffers')
        rec['garbage'] = 'hello'
        rec['header']['msg_id'] = 'fubar'
        rec2 = self.db.find_records({'msg_id' : msg_id})[0]
        self.assertTrue('buffers' in rec2)
        self.assertFalse('garbage' in rec2)
        self.assertEqual(rec2['header']['msg_id'], msg_id)

    def test_pop_safe_find_keys(self):
        """editing query results shouldn't affect record [find+keys]"""
        msg_id = self.db.get_history()[-1]
        rec = self.db.find_records({'msg_id' : msg_id}, keys=['buffers', 'header'])[0]
        rec.pop('buffers')
        rec['garbage'] = 'hello'
        rec['header']['msg_id'] = 'fubar'
        rec2 = self.db.find_records({'msg_id' : msg_id})[0]
        self.assertTrue('buffers' in rec2)
        self.assertFalse('garbage' in rec2)
        self.assertEqual(rec2['header']['msg_id'], msg_id)
def test_serialize_binary():
    """serialize_binary_message() on a buffered message yields bytes."""
    session = Session()
    message = session.msg("data_pub", content={"a": "b"})
    # a handful of tiny buffers triggers the binary serialization path
    message["buffers"] = [memoryview(os.urandom(3)) for _ in range(3)]
    wire = serialize_binary_message(message)
    assert isinstance(wire, bytes)
class TaskDBTest:
    """Shared test suite for task-record database backends (naive-datetime variant).

    Subclasses must implement ``create_db()`` to return a concrete backend;
    queries use Mongo-style operators ($lt, $gte, $in, $nin, $ne).
    NOTE(review): this variant uses naive ``datetime.now()`` timestamps —
    presumably predating the tz-aware rewrite; confirm before mixing with
    tz-aware backends.
    """

    def setUp(self):
        # Fresh session + backend, pre-populated with 16 records.
        self.session = Session()
        self.db = self.create_db()
        self.load_records(16)

    def create_db(self):
        # Abstract: concrete backend supplied by the subclass.
        raise NotImplementedError

    def load_records(self, n=1, buffer_size=100):
        """load n records for testing

        Returns the list of msg_ids added, in insertion order.
        """
        #sleep 1/10 s, to ensure timestamp is different to previous calls
        time.sleep(0.1)
        msg_ids = []
        for i in range(n):
            msg = self.session.msg('apply_request', content=dict(a=5))
            # each record carries one opaque binary buffer
            msg['buffers'] = [os.urandom(buffer_size)]
            rec = init_record(msg)
            msg_id = msg['header']['msg_id']
            msg_ids.append(msg_id)
            self.db.add_record(msg_id, rec)
        return msg_ids

    def test_add_record(self):
        # Adding 5 records extends history by exactly 5, preserving order.
        before = self.db.get_history()
        self.load_records(5)
        after = self.db.get_history()
        self.assertEqual(len(after), len(before)+5)
        self.assertEqual(after[:-5],before)

    def test_drop_record(self):
        # Dropping a record makes subsequent lookups raise KeyError.
        msg_id = self.load_records()[-1]
        rec = self.db.get_record(msg_id)
        self.db.drop_record(msg_id)
        self.assertRaises(KeyError,self.db.get_record, msg_id)

    def _round_to_millisecond(self, dt):
        """necessary because mongodb rounds microseconds"""
        micro = dt.microsecond
        # strip the sub-millisecond part (last 3 digits of microseconds)
        extra = int(str(micro)[-3:])
        return dt - timedelta(microseconds=extra)

    def test_update_record(self):
        # An update must be visible via get_record and merge cleanly.
        now = self._round_to_millisecond(datetime.now())
        #
        msg_id = self.db.get_history()[-1]
        rec1 = self.db.get_record(msg_id)
        data = {'stdout': 'hello there', 'completed' : now}
        self.db.update_record(msg_id, data)
        rec2 = self.db.get_record(msg_id)
        self.assertEqual(rec2['stdout'], 'hello there')
        self.assertEqual(rec2['completed'], now)
        rec1.update(data)
        self.assertEqual(rec1, rec2)

    # def test_update_record_bad(self):
    #     """test updating nonexistant records"""
    #     msg_id = str(uuid.uuid4())
    #     data = {'stdout': 'hello there'}
    #     self.assertRaises(KeyError, self.db.update_record, msg_id, data)

    def test_find_records_dt(self):
        """test finding records by date"""
        hist = self.db.get_history()
        # pick the middle record's timestamp as the pivot
        middle = self.db.get_record(hist[len(hist)//2])
        tic = middle['submitted']
        before = self.db.find_records({'submitted' : {'$lt' : tic}})
        after = self.db.find_records({'submitted' : {'$gte' : tic}})
        # $lt and $gte partition the full history
        self.assertEqual(len(before)+len(after),len(hist))
        for b in before:
            self.assertTrue(b['submitted'] < tic)
        for a in after:
            self.assertTrue(a['submitted'] >= tic)
        same = self.db.find_records({'submitted' : tic})
        for s in same:
            self.assertTrue(s['submitted'] == tic)

    def test_find_records_keys(self):
        """test extracting subset of record keys"""
        found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
        for rec in found:
            # msg_id is always included even though it wasn't requested
            self.assertEqual(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))

    def test_find_records_msg_id(self):
        """ensure msg_id is always in found records"""
        found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
        for rec in found:
            self.assertTrue('msg_id' in rec.keys())
        found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted'])
        for rec in found:
            self.assertTrue('msg_id' in rec.keys())
        found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['msg_id'])
        for rec in found:
            self.assertTrue('msg_id' in rec.keys())

    def test_find_records_in(self):
        """test finding records with '$in','$nin' operators"""
        hist = self.db.get_history()
        even = hist[::2]
        odd = hist[1::2]
        recs = self.db.find_records({ 'msg_id' : {'$in' : even}})
        found = [ r['msg_id'] for r in recs ]
        self.assertEqual(set(even), set(found))
        recs = self.db.find_records({ 'msg_id' : {'$nin' : even}})
        found = [ r['msg_id'] for r in recs ]
        self.assertEqual(set(odd), set(found))

    def test_get_history(self):
        # History must be ordered by (non-decreasing) submission time.
        msg_ids = self.db.get_history()
        # naive sentinel older than any record
        latest = datetime(1984,1,1)
        for msg_id in msg_ids:
            rec = self.db.get_record(msg_id)
            newt = rec['submitted']
            self.assertTrue(newt >= latest)
            latest = newt
        # a newly added record must appear at the end of history
        msg_id = self.load_records(1)[-1]
        self.assertEqual(self.db.get_history()[-1],msg_id)

    def test_datetime(self):
        """get/set timestamps with datetime objects"""
        msg_id = self.db.get_history()[-1]
        rec = self.db.get_record(msg_id)
        self.assertTrue(isinstance(rec['submitted'], datetime))
        self.db.update_record(msg_id, dict(completed=datetime.now()))
        rec = self.db.get_record(msg_id)
        self.assertTrue(isinstance(rec['completed'], datetime))

    def test_drop_matching(self):
        # drop_matching_records removes every record matched by the query
        msg_ids = self.load_records(10)
        query = {'msg_id' : {'$in':msg_ids}}
        self.db.drop_matching_records(query)
        recs = self.db.find_records(query)
        self.assertEqual(len(recs), 0)

    def test_null(self):
        """test None comparison queries"""
        msg_ids = self.load_records(10)
        query = {'msg_id' : None}
        recs = self.db.find_records(query)
        self.assertEqual(len(recs), 0)
        query = {'msg_id' : {'$ne' : None}}
        recs = self.db.find_records(query)
        self.assertTrue(len(recs) >= 10)

    def test_pop_safe_get(self):
        """editing query results shouldn't affect record [get]"""
        msg_id = self.db.get_history()[-1]
        rec = self.db.get_record(msg_id)
        # mutate the returned dict in several ways...
        rec.pop('buffers')
        rec['garbage'] = 'hello'
        rec['header']['msg_id'] = 'fubar'
        # ...then confirm a fresh fetch is untouched
        rec2 = self.db.get_record(msg_id)
        self.assertTrue('buffers' in rec2)
        self.assertFalse('garbage' in rec2)
        self.assertEqual(rec2['header']['msg_id'], msg_id)

    def test_pop_safe_find(self):
        """editing query results shouldn't affect record [find]"""
        msg_id = self.db.get_history()[-1]
        rec = self.db.find_records({'msg_id' : msg_id})[0]
        rec.pop('buffers')
        rec['garbage'] = 'hello'
        rec['header']['msg_id'] = 'fubar'
        rec2 = self.db.find_records({'msg_id' : msg_id})[0]
        self.assertTrue('buffers' in rec2)
        self.assertFalse('garbage' in rec2)
        self.assertEqual(rec2['header']['msg_id'], msg_id)

    def test_pop_safe_find_keys(self):
        """editing query results shouldn't affect record [find+keys]"""
        msg_id = self.db.get_history()[-1]
        rec = self.db.find_records({'msg_id' : msg_id}, keys=['buffers', 'header'])[0]
        rec.pop('buffers')
        rec['garbage'] = 'hello'
        rec['header']['msg_id'] = 'fubar'
        rec2 = self.db.find_records({'msg_id' : msg_id})[0]
        self.assertTrue('buffers' in rec2)
        self.assertFalse('garbage' in rec2)
        self.assertEqual(rec2['header']['msg_id'], msg_id)
def test_serialize_binary():
    """Serializing a message with buffers produces a bytes payload."""
    session = Session()
    message = session.msg('data_pub', content={'a': 'b'})
    # three small random buffers exercise the binary path
    message['buffers'] = [memoryview(os.urandom(3)) for _ in range(3)]
    wire = serialize_binary_message(message)
    assert isinstance(wire, bytes)
def test_serialize_binary():
    """Binary serialization yields bytes (nose-style assertion)."""
    session = Session()
    message = session.msg('data_pub', content={'a': 'b'})
    # attach a few tiny buffers so the binary path is taken
    message['buffers'] = [memoryview(os.urandom(3)) for _ in range(3)]
    wire = serialize_binary_message(message)
    nt.assert_is_instance(wire, bytes)
class ExecutingKernelClient(Logging):
    """Bridges a local Jupyter kernel's ZMQ sockets to RabbitMQ topics.

    Messages arriving on the kernel's ZMQ sockets are base64-encoded and
    published to RabbitMQ; messages received from RabbitMQ are decoded and
    forwarded back to the matching ZMQ socket.
    NOTE(review): uses dict.itervalues()/iteritems() — Python 2 only.
    """

    # RabbitMQ exchange and per-kernel topic patterns
    EXCHANGE = 'remote_notebook_kernel'
    EXECUTION_PUBLISHING_TOPIC = 'execution.{kernel_id}.from_external'
    EXECUTION_SUBSCRIPTION_TOPIC = 'execution.{kernel_id}.to_external'

    def __init__(self, kernel_id, signature_key, executing_kernel_client_settings):
        super(ExecutingKernelClient, self).__init__()
        self.client_settings = executing_kernel_client_settings
        self.kernel_id = kernel_id
        self.context = zmq.Context()
        # session signs outgoing kernel messages with the given key
        self.session = Session(key=signature_key)
        # stream name -> zmq socket, filled by _init_socket_forwarders
        self.subscriber = {}
        self._rabbit_sender_client, self._rabbit_listener = self._init_rabbit_clients()
        self._socket_forwarders = self._init_socket_forwarders()

    def start(self):
        """Subscribe to the kernel's RabbitMQ topic, start forwarders, init the kernel."""
        self._rabbit_listener.subscribe(
            topic=self.EXECUTION_SUBSCRIPTION_TOPIC.format(kernel_id=self.kernel_id),
            handler=self._handle_execution_message_from_rabbit)

        for forwarder in self._socket_forwarders.itervalues():
            forwarder.start()

        self._init_kernel()

    def _init_kernel(self):
        """Push connection settings into the kernel and run its init script."""
        connection_dict = self.get_connection_file_dict()
        kernel_name = connection_dict['kernel_name']
        workflow_id, node_id, port_number, dataframe_storage_type = self.client_settings.dataframe_source
        gateway_host, gateway_port = self.client_settings.gateway_address
        r_backend_host, r_backend_port = self.client_settings.r_backend_address

        # The following work both in Python and R
        self._execute_code('workflow_id = "{}"'.format(workflow_id))
        self._execute_code('node_id = {}'.format(
            '"{}"'.format(node_id) if node_id is not None else None))
        self._execute_code('dataframe_storage_type = "{}"'.format(dataframe_storage_type))
        self._execute_code('port_number = {}'.format(port_number))
        self._execute_code('gateway_address = "{}"'.format(gateway_host))
        self._execute_code('gateway_port = {}'.format(gateway_port))
        self._execute_code('r_backend_host = "{}"'.format(r_backend_host))
        self._execute_code('r_backend_port = {}'.format(r_backend_port))

        # language-specific bootstrap script, resolved relative to cwd
        if kernel_name == 'PythonExecutingKernel':
            self._execute_file(os.path.join(os.getcwd(), 'executing_kernels/python/kernel_init.py'))
        elif kernel_name == 'RExecutingKernel':
            self._execute_file(os.path.join(os.getcwd(), 'executing_kernels/r/kernel_init.R'))

    def _send_zmq_forward_to_rabbit(self, stream_name, message):
        """Publish a multipart ZMQ message (list of frames) to RabbitMQ, base64-encoded."""
        if not isinstance(message, list):
            self._exit('ExecutingKernel::_send_zmq_forward_to_rabbit: Malformed message')

        self._rabbit_sender_client.send({
            'type': 'zmq_socket_forward',
            'stream': stream_name,
            'body': [base64.b64encode(s) for s in message]
        })

    def _handle_execution_message_from_rabbit(self, message):
        """Validate an incoming RabbitMQ message and forward its body to the target ZMQ socket."""
        known_message_types = ['zmq_socket_forward']
        if not isinstance(message, dict) or 'type' not in message \
                or message['type'] not in known_message_types:
            self._exit('ExecutingKernel::_handle_execution_message_from_rabbit: Unknown message: {}'
                       .format(message))

        if message['type'] == 'zmq_socket_forward':
            if 'stream' not in message or 'body' not in message:
                self._exit('ExecutingKernel::_handle_execution_message_from_rabbit: Malformed message: {}'
                           .format(message))

            self.logger.debug('Sending to {}'.format(message['stream']))
            # frames travel base64-encoded over RabbitMQ; decode before ZMQ
            body = [base64.b64decode(s) for s in message['body']]
            self._socket_forwarders[message['stream']].forward_to_zmq(body)

    def _execute_code(self, code):
        """Send a silent execute_request with the given code to the kernel's shell socket."""
        content = dict(code=code, silent=True, user_variables=[],
                       user_expressions={}, allow_stdin=False)
        msg = self.session.msg('execute_request', content)
        ser = self.session.serialize(msg)
        self._socket_forwarders['shell'].forward_to_zmq(ser)

    def _execute_file(self, filename):
        """Execute the entire contents of *filename* in the kernel."""
        with open(filename, 'r') as f:
            self._execute_code(f.read())

    def _exit(self, msg):
        # NOTE(review): despite the name, this only logs — it does not terminate.
        self.logger.debug(msg)

    def _init_rabbit_clients(self):
        """Build the (sender, listener) RabbitMQ client pair for this kernel."""
        rabbit_client = RabbitMQClient(
            address=self.client_settings.rabbit_mq_address,
            credentials=self.client_settings.rabbit_mq_credentials,
            exchange=self.EXCHANGE)
        sender = RabbitMQJsonSender(
            rabbit_mq_client=rabbit_client,
            topic=self.EXECUTION_PUBLISHING_TOPIC.format(kernel_id=self.kernel_id))
        listener = RabbitMQJsonReceiver(rabbit_client)
        return sender, listener

    def get_connection_file_dict(self):
        """Read and parse this kernel's connection file (kernel-<id>.json in cwd)."""
        self.logger.debug('Reading connection file {}'.format(os.getcwd()))
        try:
            with open('kernel-' + self.kernel_id + '.json', 'r') as json_file:
                return json.load(json_file)
        except IOError as e:
            self.logger.error(os.strerror(e.errno))
            raise

    def _init_socket_forwarders(self):
        """Connect a ZMQ socket per kernel channel and wrap each in a SocketForwarder."""
        forwarders = {}
        kernel_json = self.get_connection_file_dict()

        def make_sender(stream_name):
            # bind stream_name now so each forwarder publishes under its own stream
            def sender(message):
                self._send_zmq_forward_to_rabbit(stream_name, message)
            return sender

        for socket in ['shell', 'control', 'stdin']:
            self.subscriber[socket] = self.context.socket(zmq.DEALER)

        # iopub is PUB socket, we treat it differently and have to set SUBSCRIPTION topic
        self.subscriber['iopub'] = self.context.socket(zmq.SUB)
        self.subscriber['iopub'].setsockopt(zmq.SUBSCRIBE, b'')

        for (socket, zmq_socket) in self.subscriber.iteritems():
            # port numbers come from the connection file ('<channel>_port' keys)
            zmq_socket.connect('tcp://localhost:' + str(kernel_json[socket + '_port']))
            forwarders[socket] = SocketForwarder(
                stream_name=socket,
                zmq_socket=zmq_socket,
                to_rabbit_sender=make_sender(socket))

        return forwarders
class WebSocketChannelsHandler(WebSocketHandler, JupyterHandler):
    """Tornado websocket handler that proxies notebook-client traffic to a gateway.

    Client messages are handed to a GatewayWebSocketClient; gateway replies come
    back via write_message. Keeps the connection alive with periodic pings.
    """

    # populated per-connection in initialize()/get()/open()
    session = None
    gateway = None
    kernel_id = None
    ping_callback = None

    def check_origin(self, origin=None):
        # defer origin policy to the Jupyter handler, not tornado's default
        return JupyterHandler.check_origin(self, origin)

    def set_default_headers(self):
        """Undo the set_default_headers in IPythonHandler which doesn't make sense for websockets"""
        pass

    def get_compression_options(self):
        # use deflate compress websocket
        return {}

    def authenticate(self):
        """Run before finishing the GET request

        Extend this method to add logic that should fire before
        the websocket finishes completing.
        """
        # authenticate the request before opening the websocket
        if self.get_current_user() is None:
            self.log.warning("Couldn't authenticate WebSocket connection")
            raise web.HTTPError(403)

        # adopt the client-provided session id, if any
        if self.get_argument('session_id', False):
            self.session.session = cast_unicode(self.get_argument('session_id'))
        else:
            self.log.warning("No session ID specified")

    def initialize(self):
        self.log.debug("Initializing websocket connection %s", self.request.path)
        self.session = Session(config=self.config)
        self.gateway = GatewayWebSocketClient(gateway_url=GatewayClient.instance().url)

    async def get(self, kernel_id, *args, **kwargs):
        # authenticate first; only then complete the websocket handshake
        self.authenticate()
        self.kernel_id = cast_unicode(kernel_id, 'ascii')
        await super(WebSocketChannelsHandler, self).get(kernel_id=kernel_id, *args, **kwargs)

    def send_ping(self):
        # stop pinging once the connection is gone
        if self.ws_connection is None and self.ping_callback is not None:
            self.ping_callback.stop()
            return

        self.ping(b'')

    def on_kernel_restart(self):
        """Inform client about kernel restart"""
        msg = self.session.msg("status", {'execution_state': 'restarting'})
        msg['channel'] = 'iopub'
        self.write_message(json.dumps(msg, default=date_default))

    def open(self, kernel_id, *args, **kwargs):
        """Handle web socket connection open to notebook server and delegate to gateway web socket handler """
        self.ping_callback = PeriodicCallback(self.send_ping, GATEWAY_WS_PING_INTERVAL_SECS * 1000)
        self.ping_callback.start()
        # get notified when the kernel restarts so the client can be told
        self.kernel_manager.add_restart_callback(kernel_id, self.on_kernel_restart)

        self.gateway.on_open(
            kernel_id=kernel_id,
            message_callback=self.write_message,
            compression_options=self.get_compression_options())

    def on_message(self, message):
        """Forward message to gateway web socket handler."""
        self.gateway.on_message(message)

    def write_message(self, message, binary=False):
        """Send message back to notebook client.  This is called via callback from self.gateway._read_messages."""
        if self.ws_connection:  # prevent WebSocketClosedError
            if isinstance(message, bytes):
                binary = True
            super(WebSocketChannelsHandler, self).write_message(message, binary=binary)
        elif self.log.isEnabledFor(logging.DEBUG):
            # connection already closed: log a summary of the dropped message
            msg_summary = WebSocketChannelsHandler._get_message_summary(json_decode(utf8(message)))
            self.log.debug("Notebook client closed websocket connection - message dropped: {}".format(msg_summary))

    def on_close(self):
        self.log.debug("Closing websocket connection %s", self.request.path)
        self.gateway.on_close()
        super(WebSocketChannelsHandler, self).on_close()

    @staticmethod
    def _get_message_summary(message):
        """Return a short, log-safe one-line description of a kernel message."""
        summary = []
        message_type = message['msg_type']
        summary.append('type: {}'.format(message_type))

        if message_type == 'status':
            summary.append(', state: {}'.format(message['content']['execution_state']))
        elif message_type == 'error':
            summary.append(', {}:{}:{}'.format(message['content']['ename'],
                                               message['content']['evalue'],
                                               message['content']['traceback']))
        else:
            summary.append(', ...')  # don't display potentially sensitive data

        return ''.join(summary)