示例#1
0
def test_default_version():
    """adapt() gives a version-less header the V4toV5 version."""
    s = Session()
    msg = s.msg('msg_type')
    # Remove the version so adapt() has to supply one.
    msg['header'].pop('version')
    original = copy.deepcopy(msg)
    adapted = adapt(original)
    assert adapted['header']['version'] == V4toV5.version
def test_default_version():
    """A header without "version" comes back from adapt() with V4toV5.version."""
    session = Session()
    message = session.msg("msg_type")
    del message["header"]["version"]
    adapted = adapt(copy.deepcopy(message))
    expected = V4toV5.version
    assert adapted["header"]["version"] == expected
def test_deserialize_binary():
    """A message with binary buffers survives a serialize/deserialize round trip."""
    session = Session()
    message = session.msg('data_pub', content={'a': 'b'})
    buffers = []
    for _ in range(3):
        buffers.append(memoryview(os.urandom(2)))
    message['buffers'] = buffers
    round_tripped = deserialize_binary_message(serialize_binary_message(message))
    assert round_tripped == message
示例#4
0
def test_default_version():
    """adapt() gives a version-less header the V4toV5 version."""
    s = Session()
    msg = s.msg('msg_type')
    # Remove the version so adapt() has to supply one.
    msg['header'].pop('version')
    original = copy.deepcopy(msg)
    adapted = adapt(original)
    assert adapted['header']['version'] == V4toV5.version
示例#5
0
def test_deserialize_binary():
    """deserialize_binary_message inverts serialize_binary_message,
    including the binary buffers."""
    s = Session()
    msg = s.msg('data_pub', content={'a': 'b'})
    msg['buffers'] = [memoryview(os.urandom(2)) for _ in range(3)]
    bmsg = serialize_binary_message(msg)
    msg2 = deserialize_binary_message(bmsg)
    assert msg2 == msg
示例#6
0
def test_deserialize_binary():
    """serialize_binary_message / deserialize_binary_message must round-trip."""
    s = Session()
    original = s.msg("data_pub", content={"a": "b"})
    original["buffers"] = list(map(memoryview, (os.urandom(2) for _ in range(3))))
    restored = deserialize_binary_message(serialize_binary_message(original))
    assert restored == original
示例#7
0
        def __init__(self, *args, **kwargs):
            """Create a ZMQ context, an XREQ socket connected to the local ``port``
            (linger 1000 ms), and a Session keyed with ``session_key``."""
            super(WidgetHandler, self).__init__(*args, **kwargs)

            self.context = zmq.Context()
            self.socket = sock = self.context.socket(zmq.XREQ)
            sock.linger = 1000
            sock.connect("tcp://127.0.0.1:%s" % port)
            self.session = Session(key=session_key.encode())
def main(connection_file, interface='tcp://140.117.168.49'):
    """Watch an IOPub channel and print the stream/error messages seen on it.

    Parameters
    ----------
    connection_file : str
        Path to a JSON connection file. Its ``iopub`` port and ``key`` are
        always used; its ``interface`` entry is used only when *interface*
        is None.
    interface : str or None
        ``transport://host`` prefix to connect to. Defaults to the address
        this script previously hard-coded. Pass None to use the
        ``interface`` entry of the connection file instead.

    Returns None when a KeyboardInterrupt arrives while waiting for a
    message. The SUB socket is closed on every exit path.
    """
    ctx = zmq.Context.instance()

    with open(connection_file) as f:
        cfg = json.loads(f.read())

    # Only read cfg['interface'] when the caller asks for it.
    reg_url = cfg['interface'] if interface is None else interface
    iopub_port = cfg['iopub']
    iopub_url = "{}:{}".format(reg_url, iopub_port)
    print("iopub_url:", iopub_url)

    session = Session(key=cfg['key'].encode('ascii'))
    sub = ctx.socket(zmq.SUB)

    # This will subscribe to all messages:
    sub.SUBSCRIBE = b''
    # replace with b'' with b'engine.1.stdout' to subscribe only to engine 1's stdout
    # 0MQ subscriptions are simple 'foo*' matches, so 'engine.1.' subscribes
    # to everything from engine 1, but there is no way to subscribe to
    # just stdout from everyone.
    # multiple calls to subscribe will add subscriptions, e.g. to subscribe to
    # engine 1's stderr and engine 2's stdout:
    # sub.SUBSCRIBE = b'engine.1.stderr'
    # sub.SUBSCRIBE = b'engine.2.stdout'
    sub.connect(iopub_url)
    try:
        while True:
            try:
                idents, msg = session.recv(sub, mode=0)
            except KeyboardInterrupt:
                return
            # ident always length 1 here
            topic = idents[0].decode('utf8', 'replace')
            if msg['msg_type'] == 'stream':
                # stdout/stderr
                # stream names are in msg['content']['name'], if you want to handle
                # them differently
                print("{}: {}".format(topic, msg['content']['text']))
            elif msg['msg_type'] == 'error':
                # Python traceback: print each line indented under the topic
                c = msg['content']
                print(topic + ':')
                for line in c['traceback']:
                    print('    ' + line)
    finally:
        # ctx is the shared process-wide instance, so only close our socket.
        sub.close()
    def __init__(self, kernel_id, signature_key,
                 executing_kernel_client_settings):
        """Store the kernel id and client settings, then build the ZMQ
        context, a Session signed with ``signature_key``, an empty
        subscriber map, the rabbit clients and the socket forwarders."""
        super(ExecutingKernelClient, self).__init__()
        self.client_settings = executing_kernel_client_settings
        self.kernel_id = kernel_id
        self.context = zmq.Context()
        self.session = Session(key=signature_key)
        self.subscriber = dict()

        rabbit_clients = self._init_rabbit_clients()
        self._rabbit_sender_client, self._rabbit_listener = rabbit_clients
        self._socket_forwarders = self._init_socket_forwarders()
示例#10
0
def main(connection_file):
    """Watch the IOPub channel and print the stream/error messages seen on it.

    Parameters
    ----------
    connection_file : str
        Path to a JSON connection file providing ``interface``, ``iopub``
        and ``key``.

    Returns None when a KeyboardInterrupt arrives while waiting for a
    message. The SUB socket is closed on every exit path.
    """
    ctx = zmq.Context.instance()

    with open(connection_file) as f:
        cfg = json.loads(f.read())

    reg_url = cfg['interface']
    iopub_port = cfg['iopub']
    iopub_url = "%s:%s" % (reg_url, iopub_port)

    session = Session(key=cfg['key'].encode('ascii'))
    sub = ctx.socket(zmq.SUB)

    # This will subscribe to all messages:
    sub.SUBSCRIBE = b''
    # replace with b'' with b'engine.1.stdout' to subscribe only to engine 1's stdout
    # 0MQ subscriptions are simple 'foo*' matches, so 'engine.1.' subscribes
    # to everything from engine 1, but there is no way to subscribe to
    # just stdout from everyone.
    # multiple calls to subscribe will add subscriptions, e.g. to subscribe to
    # engine 1's stderr and engine 2's stdout:
    # sub.SUBSCRIBE = b'engine.1.stderr'
    # sub.SUBSCRIBE = b'engine.2.stdout'
    sub.connect(iopub_url)
    try:
        while True:
            try:
                idents, msg = session.recv(sub, mode=0)
            except KeyboardInterrupt:
                return
            # ident always length 1 here
            topic = idents[0].decode('utf8', 'replace')
            if msg['msg_type'] == 'stream':
                # stdout/stderr
                # stream names are in msg['content']['name'], if you want to handle
                # them differently
                print("%s: %s" % (topic, msg['content']['text']))
            elif msg['msg_type'] == 'error':
                # Python traceback: print each line indented under the topic
                c = msg['content']
                print(topic + ':')
                for line in c['traceback']:
                    print('    ' + line)
    finally:
        # ctx is the shared process-wide instance, so only close our socket.
        sub.close()
示例#11
0
    def load_connection_file(self) -> None:
        """Load config from a JSON connection file.

        Values are applied at a *lower* priority than command-line/config
        files. The exceptions are the basic URLs, serialization settings and
        key, where the JSON content takes top priority (see inline comment).

        The same content can be supplied through the $IPP_CONNECTION_INFO env
        variable (``self.connection_info_env``). When that is set, it is
        parsed instead of reading ``self.url_file``.

        Side effects:
        - may set ``self.location``, ``self.sshserver`` and ``self.curve_serverkey``;
        - sets ``self.registration_url``;
        - sets the key, signature_scheme, packer and unpacker on ``config.Session``;
        - sets ``self.session`` and ``self.connection_info``. The latter is
          the parsed dict, with its ``interface`` host rewritten by
          ``disambiguate_ip_address``.

        Raises KeyError if a required entry is missing: ``interface``,
        ``signature_scheme``, ``registration``, ``pack``, ``unpack``, or
        ``location`` (unless it was given on the command line).
        """
        config = self.config

        # d: the connection-info dict, from the env variable or the file
        if self.connection_info_env:
            self.log.info("Loading connection info from $IPP_CONNECTION_INFO")
            d = json.loads(self.connection_info_env)
        else:
            self.log.info("Loading connection file %r", self.url_file)
            with open(self.url_file) as f:
                d = json.load(f)

        # allow hand-override of location for disambiguation
        # and ssh-server
        if 'IPEngine.location' not in self.cli_config:
            self.location = d['location']
        if 'ssh' in d and not self.sshserver:
            self.sshserver = d.get("ssh")

        # interface is expected to look like 'proto://ip'; any other shape
        # makes this two-name unpacking raise ValueError
        proto, ip = d['interface'].split('://')
        ip = disambiguate_ip_address(ip, self.location)
        d['interface'] = f'{proto}://{ip}'

        if d.get('curve_serverkey'):
            # connection file takes precedence over env, if present and defined
            self.curve_serverkey = d['curve_serverkey'].encode('ascii')
        if self.curve_serverkey:
            self.log.info("Using CurveZMQ security")
            self._ensure_curve_keypair()
        else:
            self.log.warning("Not using CurveZMQ security")

        # DO NOT allow override of basic URLs, serialization, or key
        # JSON file takes top priority there
        # (a non-empty key from the file always wins; an empty/missing one is
        # only written, as b'', when no key was configured elsewhere)
        if d.get('key') or 'key' not in config.Session:
            config.Session.key = d.get('key', '').encode('utf8')
        config.Session.signature_scheme = d['signature_scheme']

        self.registration_url = f"{d['interface']}:{d['registration']}"

        config.Session.packer = d['pack']
        config.Session.unpacker = d['unpack']
        # Created only after config.Session is updated. Presumably the
        # Session picks those settings up via parent=self (traitlets
        # config); confirm.
        self.session = Session(parent=self)

        self.log.debug("Config changed:")
        self.log.debug("%r", config)
        self.connection_info = d
def test_load_connection_file_session_with_kn():
    """load_connection_file() writes the connection file's key and
    signature scheme onto the session the app holds after initialize()."""
    app = DummyConsoleApp(session=Session())
    app.initialize(argv=[])
    # Grab the session *before* loading. The asserts below check that
    # loading the connection file updated this same object in place.
    session = app.session

    with TemporaryDirectory() as d:
        cf = os.path.join(d, 'kernel.json')
        connect.write_connection_file(cf, **sample_info_kn)
        app.connection_file = cf
        app.load_connection_file()

    assert session.key == sample_info_kn['key']
    assert session.signature_scheme == sample_info_kn['signature_scheme']
示例#13
0
def test_io_api():
    """OutStream exposes the TextIO API and raises io.UnsupportedOperation
    for the operations it does not support."""
    sess = Session()
    context = zmq.Context()
    pub_socket = context.socket(zmq.PUB)
    iopub_thread = IOPubThread(pub_socket)
    iopub_thread.start()

    out = OutStream(sess, iopub_thread, 'stdout')

    # Tear the zmq machinery down first; the API checks below don't need it.
    iopub_thread.stop()
    iopub_thread.close()
    context.term()

    assert out.errors is None
    assert not out.isatty()
    unsupported = (
        lambda: out.detach(),
        lambda: next(out),
        lambda: out.read(),
        lambda: out.readline(),
        lambda: out.seek(),
        lambda: out.tell(),
    )
    for operation in unsupported:
        with nt.assert_raises(io.UnsupportedOperation):
            operation()
    def start_watching_activity(self, kernel_id):
        """Track IOPub traffic from ``kernel_id`` to keep its activity state current.

        * every IOPub message refreshes ``kernel.last_activity``
        * ``status`` messages also update ``kernel.execution_state``
        """
        kernel = self._kernels[kernel_id]
        # Initial markers, used until the first IOPub message arrives.
        kernel.execution_state = 'starting'
        kernel.last_activity = utcnow()
        kernel._activity_stream = kernel.connect_iopub()
        session = Session(config=kernel.session.config, key=kernel.session.key)

        def record_activity(msg_list):
            """Handle one IOPub message list received from the kernel."""
            kernel.last_activity = utcnow()

            _idents, frames = session.feed_identities(msg_list)
            msg = session.deserialize(frames)
            msg_type = msg['header']['msg_type']
            self.log.debug("activity on %s: %s", kernel_id, msg_type)
            if msg_type != 'status':
                return
            kernel.execution_state = msg['content']['execution_state']

        kernel._activity_stream.on_recv(record_activity)
示例#15
0
        def __init__(self, *args, **kwargs):
            """Open an XREQ socket (linger 1000 ms) to the local ``port`` and
            build a Session keyed with ``session_key``."""
            super(WidgetHandler, self).__init__(*args, **kwargs)

            self.context = zmq.Context()
            self.socket = self.context.socket(zmq.XREQ)
            self.socket.linger = 1000
            address = "tcp://127.0.0.1:%s" % port
            self.socket.connect(address)
            signing_key = session_key.encode()
            self.session = Session(key=signing_key)
def test_io_isatty():
    """OutStream created with isatty=True reports isatty() as True."""
    session = Session()
    ctx = zmq.Context()
    pub = ctx.socket(zmq.PUB)
    thread = IOPubThread(pub)
    thread.start()
    try:
        stream = OutStream(session, thread, 'stdout', isatty=True)
        assert stream.isatty()
    finally:
        # Stop the IOPub thread and release the zmq objects even when the
        # assertion fails. Otherwise the thread and context outlive the test.
        thread.stop()
        thread.close()
        ctx.term()
示例#17
0
    class WidgetHandler(FileSystemEventHandler):
        """Send every filesystem event to a widget as a ``comm_msg`` over ZMQ."""

        def __init__(self, *args, **kwargs):
            super(WidgetHandler, self).__init__(*args, **kwargs)

            self.context = zmq.Context()
            self.socket = sock = self.context.socket(zmq.XREQ)
            sock.linger = 1000
            sock.connect("tcp://127.0.0.1:%s" % port)
            self.session = Session(key=session_key.encode())

        def on_any_event(self, event):
            """Send a message describing ``event`` through ``self.socket``."""
            self.session.send(self.socket, self.msg(event))

        def msg(self, event):
            """Build the ``comm_msg`` message dict describing ``event``."""
            header = {
                "msg_id": str(uuid4()),
                "username": "******",
                "session": session,
                "msg_type": "comm_msg",
                "version": "5.0",
            }
            content = {
                "comm_id": model_id,
                "data": {
                    "method": "custom",
                    "content": {
                        "event": event.event_type,
                        "is_directory": event.is_directory,
                        "src_path": event.src_path,
                        "dest_path": getattr(event, "dest_path", None),
                    },
                },
            }
            return {
                "buffers": [],
                "channel": "shell",
                "metadata": {},
                "parent_header": {},
                "header": header,
                "content": content,
            }
示例#18
0
    def _create_session(self):
        """Create ``self._session`` with username ``'kernel'`` and a random key.

        Uses ``jupyter_client.session.new_id_bytes`` when it can be imported.
        Otherwise a local fallback produces an ASCII-encoded UUID4 string.
        """
        from jupyter_client.session import Session
        try:
            from jupyter_client.session import new_id_bytes
        except ImportError:

            def new_id_bytes():
                """Return a random UUID4 as ASCII bytes.

                The previous fallback returned the ``uuid.UUID`` object
                itself, which is not a bytes key.
                """
                import uuid
                return str(uuid.uuid4()).encode('ascii')

        self._session = Session(username=u'kernel', key=new_id_bytes())
示例#19
0
def create_shell(username, session_id, key):
    """Hook a fresh CapturingSocket into the SwiftShell display publisher.

    Returns ``(socket, shell)``. After this call the returned CapturingSocket
    should capture all IPython display messages.
    """
    capture = CapturingSocket()
    sess = Session(username=username, session=session_id, key=key)
    shell = SwiftShell.instance()
    publisher = shell.display_pub
    publisher.session = sess
    publisher.pub_socket = capture
    return capture, shell
示例#20
0
    class WidgetHandler(FileSystemEventHandler):
        """Relay filesystem events to a widget over a ZMQ XREQ socket."""

        def __init__(self, *args, **kwargs):
            super(WidgetHandler, self).__init__(*args, **kwargs)

            context = self.context = zmq.Context()
            sock = self.socket = context.socket(zmq.XREQ)
            sock.linger = 1000
            sock.connect("tcp://127.0.0.1:%s" % port)
            self.session = Session(key=session_key.encode())

        def on_any_event(self, event):
            """Turn ``event`` into a message and send it via the session."""
            payload = self.msg(event)
            self.session.send(self.socket, payload)

        def msg(self, event):
            """Return the ``comm_msg`` message dict describing ``event``."""
            header = dict(
                msg_id=str(uuid4()),
                username="******",
                session=session,
                msg_type="comm_msg",
                version="5.0",
            )
            content = dict(
                comm_id=model_id,
                data=dict(
                    method="custom",
                    content=dict(
                        event=event.event_type,
                        is_directory=event.is_directory,
                        src_path=event.src_path,
                        dest_path=getattr(event, "dest_path", None),
                    ),
                ),
            )
            return dict(
                buffers=[],
                channel="shell",
                metadata={},
                parent_header={},
                header=header,
                content=content,
            )
示例#21
0
 def _setup_session(self, reply, comp_id):
     """
     Register the kernel described by the untrusted ``reply`` from computer ``comp_id``.

     Stores the kernel's computer, connection info and an ``executing``
     count of 0. Also lists the kernel under ``self._comps[comp_id]["kernels"]``
     and creates a Session keyed with the connection's ``key``.
     """
     content = reply["content"]
     kernel_id, connection = content["kernel_id"], content["connection"]
     self._kernels[kernel_id] = dict(comp_id=comp_id, connection=connection, executing=0)
     self._comps[comp_id]["kernels"][kernel_id] = None
     self._sessions[kernel_id] = Session(key=connection["key"])
示例#22
0
    def __init__(
        self,
        *,
        pid: int,
        engine_id: int,
        control_url: str,
        registration_url: str,
        identity: bytes,
        curve_serverkey: bytes,
        curve_publickey: bytes,
        curve_secretkey: bytes,
        config: Config,
        pipe,
        log_level: int = logging.INFO,
    ):
        """Record the engine's connection parameters and set up helpers.

        Besides storing the keyword arguments on ``self``, this:
        - wraps ``pid`` in a ``psutil.Process`` as ``parent_process``;
        - builds a ``Session`` from ``config``;
        - creates a non-propagating logger named ``<ClassName>.<engine_id>``.
        """
        self.pid = pid
        self.engine_id = engine_id
        self.parent_process = psutil.Process(pid)
        self.control_url = control_url
        self.registration_url = registration_url
        self.identity = identity
        self.curve_serverkey = curve_serverkey
        self.curve_publickey = curve_publickey
        self.curve_secretkey = curve_secretkey
        self.config = config
        self.pipe = pipe
        self.session = Session(config=config)

        logger_name = f"{type(self).__name__}.{engine_id}"
        self.log = local_logger(logger_name, log_level)
        self.log.propagate = False

        self.control_handlers = dict(signal_request=self.signal_request)
        self._finish_called = False
示例#23
0
 def _setup_session(self, reply, comp_id, timeout=None):
     """
     Register the kernel described by the untrusted ``reply`` from computer ``comp_id``.

     ``timeout`` defaults to ``self.max_kernel_timeout``. The kernel's
     deadline is set to now plus ``timeout``.
     """
     content = reply["content"]
     kernel_id = content["kernel_id"]
     connection = content["connection"]
     timeout = self.max_kernel_timeout if timeout is None else timeout
     record = {
         "comp_id": comp_id,
         "connection": connection,
         "executing": 0,  # number of active execute_requests
         "deadline": time.time() + timeout,
         "timeout": timeout,
     }
     self._kernels[kernel_id] = record
     self._comps[comp_id]["kernels"][kernel_id] = None
     self._sessions[kernel_id] = Session(key=connection["key"])
示例#24
0
    def create_on_reply(self, kernel_id):
        """Build the callback that handles reply messages from a kernel.

        Parameters
        ----------
        kernel_id
            Kernel to listen to

        Returns
        -------
        function
            Callback taking a single 0mq message list. It forwards that list
            to ``self._on_reply`` together with ``kernel_id`` and a Session
            that copies the kernel client's config and key.
        """
        client_session = self.kernel_clients[kernel_id].session
        session = Session(config=client_session.config, key=client_session.key)

        def on_reply(msg_list):
            return self._on_reply(kernel_id, session, msg_list)

        return on_reply
示例#25
0
    def start_watching_activity(self, kernel_id):
        """Watch IOPub messages from ``kernel_id`` to track its activity.

        * each message refreshes ``kernel.last_activity`` and
          ``self.last_kernel_activity``
        * ``status`` messages also update ``kernel.execution_state``; only
          those are deserialized including their content
        """
        kernel = self._kernels[kernel_id]
        # Initial markers, used until the first IOPub message arrives.
        kernel.execution_state = "starting"
        kernel.reason = ""
        kernel.last_activity = utcnow()
        kernel._activity_stream = kernel.connect_iopub()
        session = Session(
            config=kernel.session.config,
            key=kernel.session.key,
        )

        def record_activity(msg_list):
            """Handle one IOPub message list received from the kernel."""
            now = utcnow()
            self.last_kernel_activity = now
            kernel.last_activity = now

            _idents, frames = session.feed_identities(msg_list)
            msg_type = session.deserialize(frames, content=False)["header"]["msg_type"]
            if msg_type != "status":
                self.log.debug("activity on %s: %s", kernel_id, msg_type)
                return

            # Only status messages need their content decoded.
            full_msg = session.deserialize(frames)
            kernel.execution_state = full_msg["content"]["execution_state"]
            self.log.debug(
                "activity on %s: %s (%s)",
                kernel_id,
                msg_type,
                kernel.execution_state,
            )

        kernel._activity_stream.on_recv(record_activity)
示例#26
0
def test_serialize_binary():
    """serialize_binary_message packs a message with buffers into one bytes object."""
    s = Session()
    msg = s.msg('data_pub', content={'a': 'b'})
    msg['buffers'] = [memoryview(os.urandom(3)) for _ in range(3)]
    bmsg = serialize_binary_message(msg)
    assert isinstance(bmsg, bytes)
示例#27
0
 def initialize(self):
     """Log the incoming connection, then create the Session and gateway client."""
     request_path = self.request.path
     self.log.debug("Initializing websocket connection %s", request_path)
     self.session = Session(config=self.config)
     gateway_url = GatewayClient.instance().url
     self.gateway = GatewayWebSocketClient(gateway_url=gateway_url)
示例#28
0
 def setUp(self):
     """Give every test its own fresh Session."""
     fresh_session = Session()
     self.session = fresh_session
示例#29
0
class TaskDBTest:
    """Backend-independent tests for a task-record database.

    Concrete test classes implement :meth:`create_db`. The assertion helpers
    used here (``assertEqual``, ``assertRaises``, ...) are expected to come
    from a ``unittest.TestCase`` base that the concrete class also inherits.
    """

    def setUp(self):
        """Create a Session and the backend under test, then seed 16 records."""
        self.session = Session()
        self.db = self.create_db()
        self.load_records(16)

    def create_db(self):
        """Return the database backend under test; must be overridden."""
        raise NotImplementedError

    def load_records(self, n=1, buffer_size=100):
        """Insert ``n`` ``apply_request`` records and return their msg_ids.

        Each record carries one buffer of ``buffer_size`` random bytes.
        """
        # sleep 1/10 s, to ensure timestamp is different to previous calls
        time.sleep(0.1)
        msg_ids = []
        for _ in range(n):
            msg = self.session.msg('apply_request', content=dict(a=5))
            msg['buffers'] = [os.urandom(buffer_size)]
            rec = init_record(msg)
            msg_id = msg['header']['msg_id']
            msg_ids.append(msg_id)
            self.db.add_record(msg_id, rec)
        return msg_ids

    def test_add_record(self):
        """New records are appended to the end of the history."""
        before = self.db.get_history()
        self.load_records(5)
        after = self.db.get_history()
        self.assertEqual(len(after), len(before) + 5)
        self.assertEqual(after[:-5], before)

    def test_drop_record(self):
        """A dropped record can no longer be fetched."""
        msg_id = self.load_records()[-1]
        # The record must be retrievable before it is dropped.
        self.db.get_record(msg_id)
        self.db.drop_record(msg_id)
        self.assertRaises(KeyError, self.db.get_record, msg_id)

    def _round_to_millisecond(self, dt):
        """Truncate ``dt`` to whole milliseconds.

        Necessary because mongodb does not keep sub-millisecond precision,
        so stored and expected timestamps are compared at ms resolution.
        """
        return dt - timedelta(microseconds=dt.microsecond % 1000)

    def test_update_record(self):
        """update_record merges the given fields into an existing record."""
        now = self._round_to_millisecond(datetime.now())
        msg_id = self.db.get_history()[-1]
        rec1 = self.db.get_record(msg_id)
        data = {'stdout': 'hello there', 'completed': now}
        self.db.update_record(msg_id, data)
        rec2 = self.db.get_record(msg_id)
        self.assertEqual(rec2['stdout'], 'hello there')
        self.assertEqual(rec2['completed'], now)
        rec1.update(data)
        self.assertEqual(rec1, rec2)

    # def test_update_record_bad(self):
    #     """test updating nonexistant records"""
    #     msg_id = str(uuid.uuid4())
    #     data = {'stdout': 'hello there'}
    #     self.assertRaises(KeyError, self.db.update_record, msg_id, data)

    def test_find_records_dt(self):
        """test finding records by date"""
        hist = self.db.get_history()
        middle = self.db.get_record(hist[len(hist) // 2])
        tic = middle['submitted']
        before = self.db.find_records({'submitted': {'$lt': tic}})
        after = self.db.find_records({'submitted': {'$gte': tic}})
        self.assertEqual(len(before) + len(after), len(hist))
        for b in before:
            self.assertLess(b['submitted'], tic)
        for a in after:
            self.assertGreaterEqual(a['submitted'], tic)
        same = self.db.find_records({'submitted': tic})
        for s in same:
            self.assertEqual(s['submitted'], tic)

    def test_find_records_keys(self):
        """test extracting subset of record keys"""
        found = self.db.find_records({'msg_id': {'$ne': ''}}, keys=['submitted', 'completed'])
        for rec in found:
            self.assertEqual(set(rec.keys()), {'msg_id', 'submitted', 'completed'})

    def test_find_records_msg_id(self):
        """ensure msg_id is always in found records"""
        for keys in (['submitted', 'completed'], ['submitted'], ['msg_id']):
            found = self.db.find_records({'msg_id': {'$ne': ''}}, keys=keys)
            for rec in found:
                self.assertIn('msg_id', rec.keys())

    def test_find_records_in(self):
        """test finding records with '$in','$nin' operators"""
        hist = self.db.get_history()
        even = hist[::2]
        odd = hist[1::2]
        recs = self.db.find_records({'msg_id': {'$in': even}})
        found = {r['msg_id'] for r in recs}
        self.assertEqual(set(even), found)
        recs = self.db.find_records({'msg_id': {'$nin': even}})
        found = {r['msg_id'] for r in recs}
        self.assertEqual(set(odd), found)

    def test_get_history(self):
        """History is non-decreasing in submission time and ends with the newest record."""
        msg_ids = self.db.get_history()
        latest = datetime(1984, 1, 1)
        for msg_id in msg_ids:
            rec = self.db.get_record(msg_id)
            newt = rec['submitted']
            self.assertGreaterEqual(newt, latest)
            latest = newt
        msg_id = self.load_records(1)[-1]
        self.assertEqual(self.db.get_history()[-1], msg_id)

    def test_datetime(self):
        """get/set timestamps with datetime objects"""
        msg_id = self.db.get_history()[-1]
        rec = self.db.get_record(msg_id)
        self.assertIsInstance(rec['submitted'], datetime)
        self.db.update_record(msg_id, dict(completed=datetime.now()))
        rec = self.db.get_record(msg_id)
        self.assertIsInstance(rec['completed'], datetime)

    def test_drop_matching(self):
        """drop_matching_records removes every record matching the query."""
        msg_ids = self.load_records(10)
        query = {'msg_id': {'$in': msg_ids}}
        self.db.drop_matching_records(query)
        recs = self.db.find_records(query)
        self.assertEqual(len(recs), 0)

    def test_null(self):
        """test None comparison queries"""
        self.load_records(10)

        query = {'msg_id': None}
        recs = self.db.find_records(query)
        self.assertEqual(len(recs), 0)

        query = {'msg_id': {'$ne': None}}
        recs = self.db.find_records(query)
        self.assertGreaterEqual(len(recs), 10)

    def test_pop_safe_get(self):
        """editing query results shouldn't affect record [get]"""
        msg_id = self.db.get_history()[-1]
        rec = self.db.get_record(msg_id)
        rec.pop('buffers')
        rec['garbage'] = 'hello'
        rec['header']['msg_id'] = 'fubar'
        rec2 = self.db.get_record(msg_id)
        self.assertIn('buffers', rec2)
        self.assertNotIn('garbage', rec2)
        self.assertEqual(rec2['header']['msg_id'], msg_id)

    def test_pop_safe_find(self):
        """editing query results shouldn't affect record [find]"""
        msg_id = self.db.get_history()[-1]
        rec = self.db.find_records({'msg_id': msg_id})[0]
        rec.pop('buffers')
        rec['garbage'] = 'hello'
        rec['header']['msg_id'] = 'fubar'
        rec2 = self.db.find_records({'msg_id': msg_id})[0]
        self.assertIn('buffers', rec2)
        self.assertNotIn('garbage', rec2)
        self.assertEqual(rec2['header']['msg_id'], msg_id)

    def test_pop_safe_find_keys(self):
        """editing query results shouldn't affect record [find+keys]"""
        msg_id = self.db.get_history()[-1]
        rec = self.db.find_records({'msg_id': msg_id}, keys=['buffers', 'header'])[0]
        rec.pop('buffers')
        rec['garbage'] = 'hello'
        rec['header']['msg_id'] = 'fubar'
        rec2 = self.db.find_records({'msg_id': msg_id})[0]
        self.assertIn('buffers', rec2)
        self.assertNotIn('garbage', rec2)
        self.assertEqual(rec2['header']['msg_id'], msg_id)
示例#30
0
 def initialize(self):
     """Log the incoming connection and set up the Session and gateway client."""
     request_path = self.request.path
     self.log.debug("Initializing websocket connection %s", request_path)
     self.session = Session(config=self.config)
     # TODO: make kernel client class configurable
     client = KernelGatewayWSClient()
     self.gateway = client
示例#31
0
 def setUp(self):
     """Fresh Session and database for each test, seeded with 16 records."""
     self.session = Session()
     self.db = self.create_db()
     record_count = 16
     self.load_records(record_count)
示例#32
0
文件: manager.py 项目: DT021/wau
 def _session_default(self):
     """Default Session for in-process use, created with an empty key."""
     # An empty key: in-process messages are not signed.
     unsigned_key = b''
     return Session(key=unsigned_key, parent=self)
示例#33
0
 def _session_default(self):
     """Build the default Session, parented to this object."""
     from jupyter_client.session import Session as SessionClass
     return SessionClass(parent=self)
示例#34
0
 def _default_session(self):
     """Default Session parented to this object and keyed with ``INPROCESS_KEY``."""
     from jupyter_client.session import Session as SessionClass
     return SessionClass(key=INPROCESS_KEY, parent=self)
示例#35
0
 def setUp(self):
     """Create a Session and the backend under test, then insert 16 records."""
     self.session = Session()
     backend = self.create_db()
     self.db = backend
     self.load_records(16)
示例#36
0
class TaskDBTest:
    """Backend-independent tests for a task-record database.

    This is a mixin rather than a test case: the tests call ``self.assert*``
    methods that this class does not define. A concrete subclass is expected
    to also inherit from ``unittest.TestCase`` and to override
    :meth:`create_db` to return the backend under test.
    """

    def setUp(self):
        """Create a Session and a fresh database seeded with 16 records."""
        self.session = Session()
        self.db = self.create_db()
        self.load_records(16)

    def create_db(self):
        """Return the database backend under test; subclasses must override."""
        raise NotImplementedError

    def load_records(self, n=1, buffer_size=100):
        """Add ``n`` synthetic ``apply_request`` records to ``self.db``.

        Each record carries a single random buffer of ``buffer_size`` bytes.
        Returns the msg_ids of the added records, in insertion order.
        """
        # sleep 1/10 s, to ensure timestamp is different to previous calls
        time.sleep(0.1)
        msg_ids = []
        for _ in range(n):
            msg = self.session.msg('apply_request', content=dict(a=5))
            msg['buffers'] = [os.urandom(buffer_size)]
            rec = init_record(msg)
            msg_id = msg['header']['msg_id']
            msg_ids.append(msg_id)
            self.db.add_record(msg_id, rec)
        return msg_ids

    def test_add_record(self):
        """Newly added records are appended to the end of the history."""
        before = self.db.get_history()
        self.load_records(5)
        after = self.db.get_history()
        self.assertEqual(len(after), len(before) + 5)
        self.assertEqual(after[:-5], before)

    def test_drop_record(self):
        """A dropped record can no longer be fetched."""
        msg_id = self.load_records()[-1]
        # precondition: the record is retrievable before it is dropped
        self.db.get_record(msg_id)
        self.db.drop_record(msg_id)
        self.assertRaises(KeyError, self.db.get_record, msg_id)

    def _round_to_millisecond(self, dt):
        """Return ``dt`` with its sub-millisecond microseconds removed.

        Necessary because mongodb rounds microseconds. Despite the name this
        truncates: e.g. a microsecond field of 123456 becomes 123000.
        """
        return dt - timedelta(microseconds=dt.microsecond % 1000)

    def test_update_record(self):
        """update_record merges the given fields into the stored record."""
        now = self._round_to_millisecond(util.utcnow())
        msg_id = self.db.get_history()[-1]
        rec1 = self.db.get_record(msg_id)
        data = {'stdout': 'hello there', 'completed': now}
        self.db.update_record(msg_id, data)
        rec2 = self.db.get_record(msg_id)
        self.assertEqual(rec2['stdout'], 'hello there')
        self.assertEqual(rec2['completed'], now)
        rec1.update(data)
        self.assertEqual(rec1, rec2)

    # def test_update_record_bad(self):
    #     """test updating nonexistent records"""
    #     msg_id = str(uuid.uuid4())
    #     data = {'stdout': 'hello there'}
    #     self.assertRaises(KeyError, self.db.update_record, msg_id, data)

    def test_find_records_dt(self):
        """test finding records by date"""
        hist = self.db.get_history()
        middle = self.db.get_record(hist[len(hist) // 2])
        tic = middle['submitted']
        before = self.db.find_records({'submitted': {'$lt': tic}})
        after = self.db.find_records({'submitted': {'$gte': tic}})
        self.assertEqual(len(before) + len(after), len(hist))
        for b in before:
            self.assertLess(b['submitted'], tic)
        for a in after:
            self.assertGreaterEqual(a['submitted'], tic)
        same = self.db.find_records({'submitted': tic})
        for s in same:
            self.assertEqual(s['submitted'], tic)

    def test_find_records_keys(self):
        """test extracting subset of record keys"""
        found = self.db.find_records({'msg_id': {'$ne': ''}}, keys=['submitted', 'completed'])
        for rec in found:
            self.assertEqual(set(rec.keys()), {'msg_id', 'submitted', 'completed'})

    def test_find_records_msg_id(self):
        """ensure msg_id is always in found records"""
        for keys in (['submitted', 'completed'], ['submitted'], ['msg_id']):
            found = self.db.find_records({'msg_id': {'$ne': ''}}, keys=keys)
            for rec in found:
                self.assertIn('msg_id', rec)

    def test_find_records_in(self):
        """test finding records with '$in','$nin' operators"""
        hist = self.db.get_history()
        even = hist[::2]
        odd = hist[1::2]
        recs = self.db.find_records({'msg_id': {'$in': even}})
        found = [r['msg_id'] for r in recs]
        self.assertEqual(set(even), set(found))
        recs = self.db.find_records({'msg_id': {'$nin': even}})
        found = [r['msg_id'] for r in recs]
        self.assertEqual(set(odd), set(found))

    def test_get_history(self):
        """History is ordered by submission time and ends with the newest record."""
        msg_ids = self.db.get_history()
        latest = datetime(1984, 1, 1).replace(tzinfo=utc)
        for msg_id in msg_ids:
            rec = self.db.get_record(msg_id)
            newt = rec['submitted']
            self.assertGreaterEqual(newt, latest)
            latest = newt
        msg_id = self.load_records(1)[-1]
        self.assertEqual(self.db.get_history()[-1], msg_id)

    def test_datetime(self):
        """get/set timestamps with datetime objects"""
        msg_id = self.db.get_history()[-1]
        rec = self.db.get_record(msg_id)
        self.assertIsInstance(rec['submitted'], datetime)
        self.db.update_record(msg_id, dict(completed=util.utcnow()))
        rec = self.db.get_record(msg_id)
        self.assertIsInstance(rec['completed'], datetime)

    def test_drop_matching(self):
        """drop_matching_records removes every record matching the query."""
        msg_ids = self.load_records(10)
        query = {'msg_id': {'$in': msg_ids}}
        self.db.drop_matching_records(query)
        recs = self.db.find_records(query)
        self.assertEqual(len(recs), 0)

    def test_null(self):
        """test None comparison queries"""
        self.load_records(10)

        query = {'msg_id': None}
        recs = self.db.find_records(query)
        self.assertEqual(len(recs), 0)

        query = {'msg_id': {'$ne': None}}
        recs = self.db.find_records(query)
        self.assertGreaterEqual(len(recs), 10)

    def test_pop_safe_get(self):
        """editing query results shouldn't affect record [get]"""
        msg_id = self.db.get_history()[-1]
        rec = self.db.get_record(msg_id)
        rec.pop('buffers')
        rec['garbage'] = 'hello'
        rec['header']['msg_id'] = 'fubar'
        rec2 = self.db.get_record(msg_id)
        self.assertIn('buffers', rec2)
        self.assertNotIn('garbage', rec2)
        self.assertEqual(rec2['header']['msg_id'], msg_id)

    def test_pop_safe_find(self):
        """editing query results shouldn't affect record [find]"""
        msg_id = self.db.get_history()[-1]
        rec = self.db.find_records({'msg_id': msg_id})[0]
        rec.pop('buffers')
        rec['garbage'] = 'hello'
        rec['header']['msg_id'] = 'fubar'
        rec2 = self.db.find_records({'msg_id': msg_id})[0]
        self.assertIn('buffers', rec2)
        self.assertNotIn('garbage', rec2)
        self.assertEqual(rec2['header']['msg_id'], msg_id)

    def test_pop_safe_find_keys(self):
        """editing query results shouldn't affect record [find+keys]"""
        msg_id = self.db.get_history()[-1]
        rec = self.db.find_records({'msg_id': msg_id}, keys=['buffers', 'header'])[0]
        rec.pop('buffers')
        rec['garbage'] = 'hello'
        rec['header']['msg_id'] = 'fubar'
        rec2 = self.db.find_records({'msg_id': msg_id})[0]
        self.assertIn('buffers', rec2)
        self.assertNotIn('garbage', rec2)
        self.assertEqual(rec2['header']['msg_id'], msg_id)
示例#37
0
 def initialize(self):
     """Log the new websocket connection and create its message Session."""
     request_path = self.request.path
     self.log.debug("Initializing websocket connection %s", request_path)
     self.session = Session(config=self.config)
示例#38
0
 def _default_session(self):
     """Return a new Session parented to ``self``."""
     session = Session(parent=self)
     return session