예제 #1
0
def test_default_version():
    """A message whose header lacks 'version' is adapted to V4toV5.version."""
    session = Session()
    message = session.msg("msg_type")
    # Strip the version so the adapter has to fall back to its default.
    message['header'].pop('version')
    adapted = adapt(copy.deepcopy(message))
    nt.assert_equal(adapted['header']['version'], V4toV5.version)
예제 #2
0
def test_default_version():
    """Adapting a header without 'version' yields V4toV5.version."""
    sess = Session()
    versionless = sess.msg("msg_type")
    versionless['header'].pop('version')
    # adapt() receives a deep copy, so `versionless` itself is left alone.
    result = adapt(copy.deepcopy(versionless))
    nt.assert_equal(result['header']['version'], V4toV5.version)
예제 #3
0
def test_deserialize_binary():
    """A message with binary buffers survives a serialize/deserialize round trip."""
    session = Session()
    original = session.msg("data_pub", content={"a": "b"})
    # Three random 2-byte buffers.
    original["buffers"] = [memoryview(os.urandom(2)) for _ in range(3)]
    restored = deserialize_binary_message(serialize_binary_message(original))
    assert restored == original
예제 #4
0
def test_deserialize_binary():
    """Deserializing a serialized binary message gives back an equal message."""
    sess = Session()
    sent = sess.msg('data_pub', content={'a': 'b'})
    sent['buffers'] = [memoryview(os.urandom(2)) for _ in range(3)]
    wire = serialize_binary_message(sent)
    received = deserialize_binary_message(wire)
    assert received == sent
def test_default_version():
    """With no 'version' in the header, adapt() reports V4toV5.version."""
    session = Session()
    message = session.msg("msg_type")
    message["header"].pop("version")
    # Work on a deep copy so the adapter never sees the object built above.
    snapshot = copy.deepcopy(message)
    converted = adapt(snapshot)
    assert converted["header"]["version"] == V4toV5.version
예제 #6
0
def test_deserialize_binary():
    """Round-trip a buffer-carrying message through the binary wire format."""
    session = Session()
    outgoing = session.msg('data_pub', content={'a': 'b'})
    outgoing['buffers'] = [memoryview(os.urandom(2)) for _ in range(3)]
    incoming = deserialize_binary_message(serialize_binary_message(outgoing))
    nt.assert_equal(incoming, outgoing)
예제 #7
0
class TaskDBTest:
    """Shared test cases for a task-database backend.

    Subclasses must implement ``create_db``.  The ``assert*`` helpers used
    here are not defined in this class -- presumably it is mixed into a
    ``unittest.TestCase``; confirm against the concrete test classes.
    """

    def setUp(self):
        """Create a session and a database, then preload 16 records."""
        self.session = Session()
        self.db = self.create_db()
        self.load_records(16)

    def create_db(self):
        """Return the database under test; must be overridden."""
        raise NotImplementedError

    def load_records(self, n=1, buffer_size=100):
        """Add ``n`` new records to the database and return their msg_ids.

        Each record is built from an 'apply_request' message carrying one
        random buffer of ``buffer_size`` bytes.
        """
        # sleep 1/10 s, to ensure timestamp is different to previous calls
        time.sleep(0.1)
        msg_ids = []
        for _ in range(n):
            msg = self.session.msg('apply_request', content=dict(a=5))
            msg['buffers'] = [os.urandom(buffer_size)]
            rec = init_record(msg)
            msg_id = msg['header']['msg_id']
            msg_ids.append(msg_id)
            self.db.add_record(msg_id, rec)
        return msg_ids

    def test_add_record(self):
        """New records are appended to the end of the history."""
        before = self.db.get_history()
        self.load_records(5)
        after = self.db.get_history()
        self.assertEqual(len(after), len(before) + 5)
        self.assertEqual(after[:-5], before)

    def test_drop_record(self):
        """A dropped record can no longer be fetched."""
        msg_id = self.load_records()[-1]
        # The result is unused; the fetch is kept so the record is read once
        # before it is dropped, exactly as before.
        self.db.get_record(msg_id)
        self.db.drop_record(msg_id)
        self.assertRaises(KeyError, self.db.get_record, msg_id)

    def _round_to_millisecond(self, dt):
        """Return ``dt`` with its sub-millisecond part removed (truncation).

        necessary because mongodb rounds microseconds
        """
        return dt - timedelta(microseconds=dt.microsecond % 1000)

    def test_update_record(self):
        """Updated fields are stored; the rest of the record is unchanged."""
        now = self._round_to_millisecond(util.utcnow())
        msg_id = self.db.get_history()[-1]
        rec1 = self.db.get_record(msg_id)
        data = {'stdout': 'hello there', 'completed': now}
        self.db.update_record(msg_id, data)
        rec2 = self.db.get_record(msg_id)
        self.assertEqual(rec2['stdout'], 'hello there')
        self.assertEqual(rec2['completed'], now)
        # Applying the same update locally must reproduce the stored record.
        rec1.update(data)
        self.assertEqual(rec1, rec2)

    # def test_update_record_bad(self):
    #     """test updating nonexistant records"""
    #     msg_id = str(uuid.uuid4())
    #     data = {'stdout': 'hello there'}
    #     self.assertRaises(KeyError, self.db.update_record, msg_id, data)

    def test_find_records_dt(self):
        """test finding records by date"""
        hist = self.db.get_history()
        middle = self.db.get_record(hist[len(hist) // 2])
        tic = middle['submitted']
        before = self.db.find_records({'submitted': {'$lt': tic}})
        after = self.db.find_records({'submitted': {'$gte': tic}})
        # '$lt' and '$gte' together must partition the whole history.
        self.assertEqual(len(before) + len(after), len(hist))
        for b in before:
            self.assertLess(b['submitted'], tic)
        for a in after:
            self.assertGreaterEqual(a['submitted'], tic)
        same = self.db.find_records({'submitted': tic})
        for s in same:
            self.assertEqual(s['submitted'], tic)

    def test_find_records_keys(self):
        """test extracting subset of record keys"""
        found = self.db.find_records({'msg_id': {'$ne': ''}},
                                     keys=['submitted', 'completed'])
        for rec in found:
            self.assertEqual(set(rec.keys()),
                             {'msg_id', 'submitted', 'completed'})

    def test_find_records_msg_id(self):
        """ensure msg_id is always in found records"""
        for keys in (['submitted', 'completed'], ['submitted'], ['msg_id']):
            found = self.db.find_records({'msg_id': {'$ne': ''}}, keys=keys)
            for rec in found:
                self.assertIn('msg_id', rec)

    def test_find_records_in(self):
        """test finding records with '$in','$nin' operators"""
        hist = self.db.get_history()
        even = hist[::2]
        odd = hist[1::2]
        recs = self.db.find_records({'msg_id': {'$in': even}})
        self.assertEqual(set(even), {r['msg_id'] for r in recs})
        recs = self.db.find_records({'msg_id': {'$nin': even}})
        self.assertEqual(set(odd), {r['msg_id'] for r in recs})

    def test_get_history(self):
        """History is in non-decreasing 'submitted' order; new records go last."""
        msg_ids = self.db.get_history()
        latest = datetime(1984, 1, 1).replace(tzinfo=utc)
        for msg_id in msg_ids:
            rec = self.db.get_record(msg_id)
            newt = rec['submitted']
            self.assertGreaterEqual(newt, latest)
            latest = newt
        msg_id = self.load_records(1)[-1]
        self.assertEqual(self.db.get_history()[-1], msg_id)

    def test_datetime(self):
        """get/set timestamps with datetime objects"""
        msg_id = self.db.get_history()[-1]
        rec = self.db.get_record(msg_id)
        self.assertIsInstance(rec['submitted'], datetime)
        self.db.update_record(msg_id, dict(completed=util.utcnow()))
        rec = self.db.get_record(msg_id)
        self.assertIsInstance(rec['completed'], datetime)

    def test_drop_matching(self):
        """Records matched by a query are gone after drop_matching_records."""
        msg_ids = self.load_records(10)
        query = {'msg_id': {'$in': msg_ids}}
        self.db.drop_matching_records(query)
        recs = self.db.find_records(query)
        self.assertEqual(len(recs), 0)

    def test_null(self):
        """test None comparison queries"""
        self.load_records(10)

        query = {'msg_id': None}
        recs = self.db.find_records(query)
        self.assertEqual(len(recs), 0)

        query = {'msg_id': {'$ne': None}}
        recs = self.db.find_records(query)
        self.assertGreaterEqual(len(recs), 10)

    def _scribble_on(self, rec):
        """Edit a fetched record in ways that must not reach the database."""
        rec.pop('buffers')
        rec['garbage'] = 'hello'
        rec['header']['msg_id'] = 'fubar'

    def _assert_unscathed(self, rec, msg_id):
        """Check that a re-fetched record shows none of the ``_scribble_on`` edits."""
        self.assertIn('buffers', rec)
        self.assertNotIn('garbage', rec)
        self.assertEqual(rec['header']['msg_id'], msg_id)

    def test_pop_safe_get(self):
        """editing query results shouldn't affect record [get]"""
        msg_id = self.db.get_history()[-1]
        self._scribble_on(self.db.get_record(msg_id))
        self._assert_unscathed(self.db.get_record(msg_id), msg_id)

    def test_pop_safe_find(self):
        """editing query results shouldn't affect record [find]"""
        msg_id = self.db.get_history()[-1]
        self._scribble_on(self.db.find_records({'msg_id': msg_id})[0])
        self._assert_unscathed(
            self.db.find_records({'msg_id': msg_id})[0], msg_id)

    def test_pop_safe_find_keys(self):
        """editing query results shouldn't affect record [find+keys]"""
        msg_id = self.db.get_history()[-1]
        self._scribble_on(
            self.db.find_records({'msg_id': msg_id},
                                 keys=['buffers', 'header'])[0])
        self._assert_unscathed(
            self.db.find_records({'msg_id': msg_id})[0], msg_id)
예제 #8
0
def test_serialize_binary():
    """Serializing a message with binary buffers produces ``bytes``."""
    session = Session()
    message = session.msg("data_pub", content={"a": "b"})
    # Three random 3-byte buffers.
    message["buffers"] = [memoryview(os.urandom(3)) for _ in range(3)]
    serialized = serialize_binary_message(message)
    assert isinstance(serialized, bytes)
예제 #9
0
class TaskDBTest:
    """Shared test cases for a task-database backend.

    Subclasses must implement ``create_db``.  The ``assert*`` helpers used
    here are not defined in this class -- presumably it is mixed into a
    ``unittest.TestCase``; confirm against the concrete test classes.
    """

    def setUp(self):
        """Create a session and a database, then preload 16 records."""
        self.session = Session()
        self.db = self.create_db()
        self.load_records(16)

    def create_db(self):
        """Return the database under test; must be overridden."""
        raise NotImplementedError

    def load_records(self, n=1, buffer_size=100):
        """Add ``n`` new records to the database and return their msg_ids.

        Each record is built from an 'apply_request' message carrying one
        random buffer of ``buffer_size`` bytes.
        """
        # sleep 1/10 s, to ensure timestamp is different to previous calls
        time.sleep(0.1)
        msg_ids = []
        for _ in range(n):
            msg = self.session.msg('apply_request', content=dict(a=5))
            msg['buffers'] = [os.urandom(buffer_size)]
            rec = init_record(msg)
            msg_id = msg['header']['msg_id']
            msg_ids.append(msg_id)
            self.db.add_record(msg_id, rec)
        return msg_ids

    def test_add_record(self):
        """New records are appended to the end of the history."""
        before = self.db.get_history()
        self.load_records(5)
        after = self.db.get_history()
        self.assertEqual(len(after), len(before) + 5)
        self.assertEqual(after[:-5], before)

    def test_drop_record(self):
        """A dropped record can no longer be fetched."""
        msg_id = self.load_records()[-1]
        # The result is unused; the fetch is kept so the record is read once
        # before it is dropped, exactly as before.
        self.db.get_record(msg_id)
        self.db.drop_record(msg_id)
        self.assertRaises(KeyError, self.db.get_record, msg_id)

    def _round_to_millisecond(self, dt):
        """Return ``dt`` with its sub-millisecond part removed (truncation).

        necessary because mongodb rounds microseconds
        """
        return dt - timedelta(microseconds=dt.microsecond % 1000)

    def test_update_record(self):
        """Updated fields are stored; the rest of the record is unchanged."""
        now = self._round_to_millisecond(datetime.now())
        msg_id = self.db.get_history()[-1]
        rec1 = self.db.get_record(msg_id)
        data = {'stdout': 'hello there', 'completed': now}
        self.db.update_record(msg_id, data)
        rec2 = self.db.get_record(msg_id)
        self.assertEqual(rec2['stdout'], 'hello there')
        self.assertEqual(rec2['completed'], now)
        # Applying the same update locally must reproduce the stored record.
        rec1.update(data)
        self.assertEqual(rec1, rec2)

    # def test_update_record_bad(self):
    #     """test updating nonexistant records"""
    #     msg_id = str(uuid.uuid4())
    #     data = {'stdout': 'hello there'}
    #     self.assertRaises(KeyError, self.db.update_record, msg_id, data)

    def test_find_records_dt(self):
        """test finding records by date"""
        hist = self.db.get_history()
        middle = self.db.get_record(hist[len(hist) // 2])
        tic = middle['submitted']
        before = self.db.find_records({'submitted': {'$lt': tic}})
        after = self.db.find_records({'submitted': {'$gte': tic}})
        # '$lt' and '$gte' together must partition the whole history.
        self.assertEqual(len(before) + len(after), len(hist))
        for b in before:
            self.assertLess(b['submitted'], tic)
        for a in after:
            self.assertGreaterEqual(a['submitted'], tic)
        same = self.db.find_records({'submitted': tic})
        for s in same:
            self.assertEqual(s['submitted'], tic)

    def test_find_records_keys(self):
        """test extracting subset of record keys"""
        found = self.db.find_records({'msg_id': {'$ne': ''}},
                                     keys=['submitted', 'completed'])
        for rec in found:
            self.assertEqual(set(rec.keys()),
                             {'msg_id', 'submitted', 'completed'})

    def test_find_records_msg_id(self):
        """ensure msg_id is always in found records"""
        for keys in (['submitted', 'completed'], ['submitted'], ['msg_id']):
            found = self.db.find_records({'msg_id': {'$ne': ''}}, keys=keys)
            for rec in found:
                self.assertIn('msg_id', rec)

    def test_find_records_in(self):
        """test finding records with '$in','$nin' operators"""
        hist = self.db.get_history()
        even = hist[::2]
        odd = hist[1::2]
        recs = self.db.find_records({'msg_id': {'$in': even}})
        self.assertEqual(set(even), {r['msg_id'] for r in recs})
        recs = self.db.find_records({'msg_id': {'$nin': even}})
        self.assertEqual(set(odd), {r['msg_id'] for r in recs})

    def test_get_history(self):
        """History is in non-decreasing 'submitted' order; new records go last."""
        msg_ids = self.db.get_history()
        latest = datetime(1984, 1, 1)
        for msg_id in msg_ids:
            rec = self.db.get_record(msg_id)
            newt = rec['submitted']
            self.assertGreaterEqual(newt, latest)
            latest = newt
        msg_id = self.load_records(1)[-1]
        self.assertEqual(self.db.get_history()[-1], msg_id)

    def test_datetime(self):
        """get/set timestamps with datetime objects"""
        msg_id = self.db.get_history()[-1]
        rec = self.db.get_record(msg_id)
        self.assertIsInstance(rec['submitted'], datetime)
        self.db.update_record(msg_id, dict(completed=datetime.now()))
        rec = self.db.get_record(msg_id)
        self.assertIsInstance(rec['completed'], datetime)

    def test_drop_matching(self):
        """Records matched by a query are gone after drop_matching_records."""
        msg_ids = self.load_records(10)
        query = {'msg_id': {'$in': msg_ids}}
        self.db.drop_matching_records(query)
        recs = self.db.find_records(query)
        self.assertEqual(len(recs), 0)

    def test_null(self):
        """test None comparison queries"""
        self.load_records(10)

        query = {'msg_id': None}
        recs = self.db.find_records(query)
        self.assertEqual(len(recs), 0)

        query = {'msg_id': {'$ne': None}}
        recs = self.db.find_records(query)
        self.assertGreaterEqual(len(recs), 10)

    def _scribble_on(self, rec):
        """Edit a fetched record in ways that must not reach the database."""
        rec.pop('buffers')
        rec['garbage'] = 'hello'
        rec['header']['msg_id'] = 'fubar'

    def _assert_unscathed(self, rec, msg_id):
        """Check that a re-fetched record shows none of the ``_scribble_on`` edits."""
        self.assertIn('buffers', rec)
        self.assertNotIn('garbage', rec)
        self.assertEqual(rec['header']['msg_id'], msg_id)

    def test_pop_safe_get(self):
        """editing query results shouldn't affect record [get]"""
        msg_id = self.db.get_history()[-1]
        self._scribble_on(self.db.get_record(msg_id))
        self._assert_unscathed(self.db.get_record(msg_id), msg_id)

    def test_pop_safe_find(self):
        """editing query results shouldn't affect record [find]"""
        msg_id = self.db.get_history()[-1]
        self._scribble_on(self.db.find_records({'msg_id': msg_id})[0])
        self._assert_unscathed(
            self.db.find_records({'msg_id': msg_id})[0], msg_id)

    def test_pop_safe_find_keys(self):
        """editing query results shouldn't affect record [find+keys]"""
        msg_id = self.db.get_history()[-1]
        self._scribble_on(
            self.db.find_records({'msg_id': msg_id},
                                 keys=['buffers', 'header'])[0])
        self._assert_unscathed(
            self.db.find_records({'msg_id': msg_id})[0], msg_id)
예제 #10
0
def test_serialize_binary():
    """serialize_binary_message() returns a ``bytes`` object."""
    sess = Session()
    payload = sess.msg('data_pub', content={'a': 'b'})
    payload['buffers'] = [memoryview(os.urandom(3)) for _ in range(3)]
    wire = serialize_binary_message(payload)
    assert isinstance(wire, bytes)
예제 #11
0
def test_serialize_binary():
    """A message with three random buffers serializes to ``bytes``."""
    session = Session()
    message = session.msg('data_pub', content={'a': 'b'})
    message['buffers'] = [memoryview(os.urandom(3)) for _ in range(3)]
    encoded = serialize_binary_message(message)
    nt.assert_is_instance(encoded, bytes)
예제 #12
0
class ExecutingKernelClient(Logging):
    """Relay messages between a local kernel's ZMQ sockets and RabbitMQ.

    One ``SocketForwarder`` is created per kernel stream ('shell',
    'control', 'stdin', 'iopub').  Traffic read from ZMQ is base64-encoded
    and published to ``EXECUTION_PUBLISHING_TOPIC``; messages received on
    ``EXECUTION_SUBSCRIPTION_TOPIC`` are decoded and pushed to the matching
    ZMQ socket.
    """

    EXCHANGE = 'remote_notebook_kernel'
    EXECUTION_PUBLISHING_TOPIC = 'execution.{kernel_id}.from_external'
    EXECUTION_SUBSCRIPTION_TOPIC = 'execution.{kernel_id}.to_external'

    def __init__(self, kernel_id, signature_key,
                 executing_kernel_client_settings):
        """Create the ZMQ context, the session and all forwarding clients.

        :param kernel_id: id used in the topic names and in the name of the
            connection file ('kernel-<kernel_id>.json').
        :param signature_key: key handed to ``Session``.
        :param executing_kernel_client_settings: object providing
            ``rabbit_mq_address``, ``rabbit_mq_credentials``,
            ``dataframe_source``, ``gateway_address`` and
            ``r_backend_address``.
        """
        super(ExecutingKernelClient, self).__init__()
        self.client_settings = executing_kernel_client_settings
        self.kernel_id = kernel_id
        self.context = zmq.Context()
        self.session = Session(key=signature_key)
        # stream name -> ZMQ socket connected to the kernel
        self.subscriber = {}

        (self._rabbit_sender_client,
         self._rabbit_listener) = self._init_rabbit_clients()
        self._socket_forwarders = self._init_socket_forwarders()

    def start(self):
        """Subscribe to RabbitMQ, start every forwarder, then init the kernel."""
        self._rabbit_listener.subscribe(
            topic=self.EXECUTION_SUBSCRIPTION_TOPIC.format(
                kernel_id=self.kernel_id),
            handler=self._handle_execution_message_from_rabbit)

        # values() instead of the Python-2-only itervalues().
        for forwarder in self._socket_forwarders.values():
            forwarder.start()

        self._init_kernel()

    def _init_kernel(self):
        """Define the bootstrap variables in the kernel and run its init file."""
        connection_dict = self.get_connection_file_dict()
        kernel_name = connection_dict['kernel_name']
        (workflow_id, node_id, port_number,
         dataframe_storage_type) = self.client_settings.dataframe_source
        gateway_host, gateway_port = self.client_settings.gateway_address
        r_backend_host, r_backend_port = self.client_settings.r_backend_address

        # NOTE(review): the values below are interpolated into source code
        # unescaped; this assumes the settings are trusted -- confirm.
        # The following work both in Python and R
        self._execute_code('workflow_id = "{}"'.format(workflow_id))
        self._execute_code('node_id = {}'.format(
            '"{}"'.format(node_id) if node_id is not None else None))
        self._execute_code(
            'dataframe_storage_type = "{}"'.format(dataframe_storage_type))
        self._execute_code('port_number = {}'.format(port_number))
        self._execute_code('gateway_address = "{}"'.format(gateway_host))
        self._execute_code('gateway_port = {}'.format(gateway_port))
        self._execute_code('r_backend_host = "{}"'.format(r_backend_host))
        self._execute_code('r_backend_port = {}'.format(r_backend_port))

        if kernel_name == 'PythonExecutingKernel':
            self._execute_file(
                os.path.join(os.getcwd(),
                             'executing_kernels/python/kernel_init.py'))
        elif kernel_name == 'RExecutingKernel':
            self._execute_file(
                os.path.join(os.getcwd(), 'executing_kernels/r/kernel_init.R'))

    def _send_zmq_forward_to_rabbit(self, stream_name, message):
        """Publish a multipart ZMQ message to RabbitMQ, base64-encoding each part.

        Anything that is not a list is reported through ``_exit`` and
        dropped.
        """
        if not isinstance(message, list):
            self._exit(
                'ExecutingKernel::_send_zmq_forward_to_rabbit: Malformed message'
            )
            # _exit only logs, so stop here rather than encode a bad payload.
            return

        self._rabbit_sender_client.send({
            'type': 'zmq_socket_forward',
            'stream': stream_name,
            'body': [base64.b64encode(part) for part in message],
        })

    def _handle_execution_message_from_rabbit(self, message):
        """Decode a message received from RabbitMQ and forward it to ZMQ.

        Unknown or malformed messages are reported through ``_exit`` and
        ignored.
        """
        known_message_types = ['zmq_socket_forward']
        if not isinstance(message, dict) or 'type' not in message \
                or message['type'] not in known_message_types:
            self._exit(
                'ExecutingKernel::_handle_execution_message_from_rabbit: Unknown message: {}'
                .format(message))
            # _exit only logs; without this return the lookups below would
            # raise on exactly the input this guard detects.
            return

        if message['type'] == 'zmq_socket_forward':
            if 'stream' not in message or 'body' not in message:
                self._exit(
                    'ExecutingKernel::_handle_execution_message_from_rabbit: Malformed message: {}'
                    .format(message))
                return

            self.logger.debug('Sending to %s', message['stream'])
            body = [base64.b64decode(part) for part in message['body']]
            self._socket_forwarders[message['stream']].forward_to_zmq(body)

    def _execute_code(self, code):
        """Send ``code`` to the kernel as a silent 'execute_request' on 'shell'."""
        content = dict(code=code,
                       silent=True,
                       user_variables=[],
                       user_expressions={},
                       allow_stdin=False)
        msg = self.session.msg('execute_request', content)
        ser = self.session.serialize(msg)

        self._socket_forwarders['shell'].forward_to_zmq(ser)

    def _execute_file(self, filename):
        """Read ``filename`` and execute its contents in the kernel."""
        with open(filename, 'r') as f:
            self._execute_code(f.read())

    def _exit(self, msg):
        """Report a fatal-looking condition; currently this only logs ``msg``."""
        self.logger.debug(msg)

    def _init_rabbit_clients(self):
        """Build and return the ``(sender, listener)`` pair for ``EXCHANGE``."""
        rabbit_client = RabbitMQClient(
            address=self.client_settings.rabbit_mq_address,
            credentials=self.client_settings.rabbit_mq_credentials,
            exchange=self.EXCHANGE)
        sender = RabbitMQJsonSender(
            rabbit_mq_client=rabbit_client,
            topic=self.EXECUTION_PUBLISHING_TOPIC.format(
                kernel_id=self.kernel_id))
        listener = RabbitMQJsonReceiver(rabbit_client)
        return sender, listener

    def get_connection_file_dict(self):
        """Load 'kernel-<kernel_id>.json' from the current working directory.

        :return: the parsed JSON content.
        :raises IOError: re-raised after logging when the file cannot be read.
        """
        self.logger.debug('Reading connection file {}'.format(os.getcwd()))
        try:
            with open('kernel-' + self.kernel_id + '.json', 'r') as json_file:
                return json.load(json_file)
        except IOError as e:
            self.logger.error(os.strerror(e.errno))
            raise

    def _init_socket_forwarders(self):
        """Connect one ZMQ socket per stream and wrap each in a SocketForwarder.

        :return: dict mapping stream name to its ``SocketForwarder``.
        """
        forwarders = {}
        kernel_json = self.get_connection_file_dict()

        def make_sender(stream_name):
            # A factory, so each closure captures its own stream name.
            def sender(message):
                self._send_zmq_forward_to_rabbit(stream_name, message)

            return sender

        for stream in ['shell', 'control', 'stdin']:
            self.subscriber[stream] = self.context.socket(zmq.DEALER)

        # iopub is PUB socket, we treat it differently and have to set SUBSCRIPTION topic
        self.subscriber['iopub'] = self.context.socket(zmq.SUB)
        self.subscriber['iopub'].setsockopt(zmq.SUBSCRIBE, b'')

        # items() instead of the Python-2-only iteritems().
        for stream, zmq_socket in self.subscriber.items():
            zmq_socket.connect('tcp://localhost:' +
                               str(kernel_json[stream + '_port']))
            forwarders[stream] = SocketForwarder(
                stream_name=stream,
                zmq_socket=zmq_socket,
                to_rabbit_sender=make_sender(stream))

        return forwarders
예제 #13
0
class WebSocketChannelsHandler(WebSocketHandler, JupyterHandler):
    """Websocket handler that relays kernel channel traffic through a gateway.

    Messages from the notebook client are passed to a
    ``GatewayWebSocketClient``; messages coming back are written to the
    client through ``write_message``.
    """

    # Populated in initialize() / get() / open().
    session = None
    gateway = None
    kernel_id = None
    ping_callback = None

    def check_origin(self, origin=None):
        """Delegate the origin check to ``JupyterHandler``."""
        return JupyterHandler.check_origin(self, origin)

    def set_default_headers(self):
        """Undo the set_default_headers in IPythonHandler which doesn't make sense for websockets"""
        pass

    def get_compression_options(self):
        """Return the websocket compression options (an empty dict)."""
        # use deflate compress websocket
        return {}

    def authenticate(self):
        """Run before finishing the GET request

        Extend this method to add logic that should fire before
        the websocket finishes completing.

        Raises ``web.HTTPError(403)`` when there is no current user.
        """
        # authenticate the request before opening the websocket
        if self.get_current_user() is None:
            self.log.warning("Couldn't authenticate WebSocket connection")
            raise web.HTTPError(403)

        session_id = self.get_argument('session_id', False)
        if session_id:
            self.session.session = cast_unicode(session_id)
        else:
            self.log.warning("No session ID specified")

    def initialize(self):
        """Create the message session and the gateway websocket client."""
        self.log.debug("Initializing websocket connection %s",
                       self.request.path)
        self.session = Session(config=self.config)
        self.gateway = GatewayWebSocketClient(
            gateway_url=GatewayClient.instance().url)

    async def get(self, kernel_id, *args, **kwargs):
        """Authenticate, record the kernel id, then run the websocket GET."""
        self.authenticate()
        self.kernel_id = cast_unicode(kernel_id, 'ascii')
        await super().get(kernel_id=kernel_id, *args, **kwargs)

    def send_ping(self):
        """Ping the client, or stop the periodic ping once the socket is gone."""
        if self.ws_connection is None:
            # Connection closed: never ping it, and cancel the timer if one
            # was started.
            if self.ping_callback is not None:
                self.ping_callback.stop()
            return

        self.ping(b'')

    def on_kernel_restart(self):
        """Inform client about kernel restart"""
        msg = self.session.msg("status", {'execution_state': 'restarting'})
        msg['channel'] = 'iopub'
        self.write_message(json.dumps(msg, default=date_default))

    def open(self, kernel_id, *args, **kwargs):
        """Handle web socket connection open to notebook server and delegate to gateway web socket handler """
        self.ping_callback = PeriodicCallback(
            self.send_ping, GATEWAY_WS_PING_INTERVAL_SECS * 1000)
        self.ping_callback.start()
        self.kernel_manager.add_restart_callback(kernel_id,
                                                 self.on_kernel_restart)

        self.gateway.on_open(
            kernel_id=kernel_id,
            message_callback=self.write_message,
            compression_options=self.get_compression_options())

    def on_message(self, message):
        """Forward message to gateway web socket handler."""
        self.gateway.on_message(message)

    def write_message(self, message, binary=False):
        """Send message back to notebook client.  This is called via callback from self.gateway._read_messages."""
        if self.ws_connection:  # prevent WebSocketClosedError
            if isinstance(message, bytes):
                binary = True
            super().write_message(message, binary=binary)
        elif self.log.isEnabledFor(logging.DEBUG):
            # Socket already closed: the message is dropped, and only a
            # summary of it is logged.
            msg_summary = WebSocketChannelsHandler._get_message_summary(
                json_decode(utf8(message)))
            self.log.debug(
                "Notebook client closed websocket connection - message dropped: %s",
                msg_summary)

    def on_close(self):
        """Close the gateway side, then finish the normal websocket close."""
        self.log.debug("Closing websocket connection %s", self.request.path)
        self.gateway.on_close()
        super().on_close()

    @staticmethod
    def _get_message_summary(message):
        """Return a short one-line description of a decoded kernel message.

        'status' messages show their execution state, 'error' messages
        their ename/evalue/traceback; the content of every other type is
        elided.
        """
        message_type = message['msg_type']
        summary = 'type: {}'.format(message_type)

        if message_type == 'status':
            return summary + ', state: {}'.format(
                message['content']['execution_state'])
        if message_type == 'error':
            content = message['content']
            return summary + ', {}:{}:{}'.format(
                content['ename'], content['evalue'], content['traceback'])
        return summary + ', ...'  # don't display potentially sensitive data