コード例 #1
0
def test_default_version():
    """adapt() must report V4toV5.version for a header without 'version'."""
    session = Session()
    message = session.msg("msg_type")
    # Strip the version so the adapter cannot simply read it from the header.
    message['header'].pop('version')
    adapted = adapt(copy.deepcopy(message))
    nt.assert_equal(adapted['header']['version'], V4toV5.version)
コード例 #2
0
def test_default_version():
    """A message lacking header['version'] adapts to V4toV5.version."""
    message = Session().msg("msg_type")
    header = message['header']
    # Remove the version entry before handing a deep copy to adapt().
    header.pop('version')
    message_copy = copy.deepcopy(message)
    result = adapt(message_copy)
    nt.assert_equal(result['header']['version'], V4toV5.version)
コード例 #3
0
ファイル: test_serialize.py プロジェクト: wguo123/ipython
def test_deserialize_binary():
    """Serializing then deserializing a message with buffers returns an equal message."""
    session = Session()
    original = session.msg('data_pub', content={'a': 'b'})
    # Three random 2-byte buffers, wrapped as memoryviews.
    original['buffers'] = [memoryview(os.urandom(2)) for _ in range(3)]
    restored = deserialize_binary_message(serialize_binary_message(original))
    nt.assert_equal(restored, original)
コード例 #4
0
ファイル: test_adapter.py プロジェクト: ngoldbaum/ipython
def test_default_version():
    """Check that adapt() yields V4toV5.version when the header has no version."""
    session = Session()
    versionless = session.msg("msg_type")
    # Drop the version entry so the adapted header cannot just echo it back.
    versionless["header"].pop("version")
    adapted = adapt(copy.deepcopy(versionless))
    adapted_version = adapted["header"]["version"]
    nt.assert_equal(adapted_version, V4toV5.version)
コード例 #5
0
ファイル: test_serialize.py プロジェクト: 2t7/ipython
def test_deserialize_binary():
    """A binary-serialized message with buffers deserializes to an equal message."""
    session = Session()
    sent = session.msg('data_pub', content={'a': 'b'})
    # Three random 2-byte buffers.
    sent['buffers'] = [os.urandom(2) for _ in range(3)]
    wire = serialize_binary_message(sent)
    received = deserialize_binary_message(wire)
    nt.assert_equal(received, sent)
コード例 #6
0
ファイル: iopubwatcher.py プロジェクト: 3kwa/ipython
def main(connection_file):
    """watch iopub channel, and print messages

    Reads the JSON connection file (entries 'location', 'url' and
    'exec_key'), sends a "connection_request" to the endpoint at
    cfg['url'], then subscribes to the iopub URL from the reply and prints
    'stream' output and 'pyerr' tracebacks until interrupted with Ctrl-C.
    """

    ctx = zmq.Context.instance()

    with open(connection_file) as f:
        cfg = json.load(f)

    location = cfg['location']
    session = Session(key=str_to_bytes(cfg['exec_key']))

    # Ask the endpoint at cfg['url'] where the iopub channel lives.
    query = ctx.socket(zmq.DEALER)
    query.connect(disambiguate_url(cfg['url'], location))
    session.send(query, "connection_request")
    idents, msg = session.recv(query, mode=0)
    c = msg['content']
    iopub_url = disambiguate_url(c['iopub'], location)
    sub = ctx.socket(zmq.SUB)
    # This will subscribe to all messages:
    sub.setsockopt(zmq.SUBSCRIBE, b'')
    # replace b'' with b'engine.1.stdout' to subscribe only to engine 1's stdout
    # 0MQ subscriptions are simple 'foo*' matches, so 'engine.1.' subscribes
    # to everything from engine 1, but there is no way to subscribe to
    # just stdout from everyone.
    # multiple calls to subscribe will add subscriptions, e.g. to subscribe to
    # engine 1's stderr and engine 2's stdout:
    # sub.setsockopt(zmq.SUBSCRIBE, b'engine.1.stderr')
    # sub.setsockopt(zmq.SUBSCRIBE, b'engine.2.stdout')
    sub.connect(iopub_url)
    while True:
        try:
            idents, msg = session.recv(sub, mode=0)
        except KeyboardInterrupt:
            # Ctrl-C is the normal way to stop watching.
            return
        # ident always length 1 here
        topic = idents[0]
        if msg['msg_type'] == 'stream':
            # stdout/stderr
            # stream names are in msg['content']['name'], if you want to handle
            # them differently
            print("%s: %s" % (topic, msg['content']['data']))
        elif msg['msg_type'] == 'pyerr':
            # Python traceback
            c = msg['content']
            # Format instead of concatenating: the ident frame is presumably
            # bytes on Python 3 (TODO confirm), where `topic + ':'` would
            # raise TypeError.  This matches the 'stream' branch above.
            print("%s:" % (topic,))
            for line in c['traceback']:
                # indent lines
                print('    ' + line)
コード例 #7
0
def test_load_connection_file_session():
    """test load_connection_file() after the app's session has been created

    The session object taken from the app before loading must carry the
    key and signature scheme written to the connection file.
    """
    app = DummyConsoleApp(session=Session())
    app.initialize(argv=[])
    # Grab the session *before* loading so the asserts below check that the
    # already-existing object was updated.
    session = app.session

    with TemporaryDirectory() as d:
        cf = os.path.join(d, 'kernel.json')
        connect.write_connection_file(cf, **sample_info)
        app.connection_file = cf
        app.load_connection_file()

    nt.assert_equal(session.key, sample_info['key'])
    nt.assert_equal(session.signature_scheme, sample_info['signature_scheme'])
コード例 #8
0
ファイル: hub.py プロジェクト: edwardjkim/stress-proxy
def start_notebook(url, port, user):
    """Log in as `user`, then load, run and save 'Hello.ipynb' in an endless loop.

    Each pass fetches the notebook, starts a new kernel for a fresh
    Session, runs the notebook 20 times (yielding ``sleep(0.05)`` after each
    run), saves it, and always kills the kernel afterwards, even when a run
    raises.

    This is a generator: it yields whatever ``api.new_kernel``,
    ``run_notebook`` and ``sleep`` return, presumably for a coroutine
    runner whose decorator is not visible here (TODO confirm).  The loop
    iterates over ``itertools.count()`` and therefore only ends when an
    exception propagates or the generator is closed.
    """
    hub_url = 'https://%s:%s/hub' % (url, port)
    user_url = 'https://%s:%s/user/%s' % (url, port, user)

    # NOTE(review): the user name is passed twice, presumably as both login
    # name and password; confirm against login().
    cookies = login(hub_url, user, user)
    api = NBAPI(url=user_url, cookies=cookies)

    path = 'Hello.ipynb'
    for i in itertools.count():
        gen_log.info("loading %s (%s)", user, i)
        nb = api.get_notebook(path)

        gen_log.info("starting %s (%s)", user, i)
        session = Session()
        kernel = yield api.new_kernel(session.session)

        try:
            for j in range(20):
                gen_log.info("running %s (%s:%s)", user, j, i)
                yield run_notebook(nb, kernel, session)
                yield sleep(0.05)

            gen_log.info("saving %s (%s)", user, i)
            api.save_notebook(nb, path)

        finally:
            api.kill_kernel(kernel['id'])

    # Nothing may follow the loop: it never terminates normally, so any
    # statement placed here would be unreachable.
コード例 #9
0
ファイル: iopubwatcher.py プロジェクト: prashkr/ipython
def main(connection_file):
    """watch iopub channel, and print messages

    Reads the JSON connection file, builds the iopub URL from its
    'interface' and 'iopub' entries, subscribes to every topic and prints
    'stream' text and 'pyerr' tracebacks until interrupted with Ctrl-C.
    """

    ctx = zmq.Context.instance()

    with open(connection_file) as f:
        cfg = json.load(f)

    reg_url = cfg['interface']
    iopub_port = cfg['iopub']
    iopub_url = "%s:%s" % (reg_url, iopub_port)

    session = Session(key=str_to_bytes(cfg['key']))
    sub = ctx.socket(zmq.SUB)

    # This will subscribe to all messages:
    sub.setsockopt(zmq.SUBSCRIBE, b'')
    # replace b'' with b'engine.1.stdout' to subscribe only to engine 1's stdout
    # 0MQ subscriptions are simple 'foo*' matches, so 'engine.1.' subscribes
    # to everything from engine 1, but there is no way to subscribe to
    # just stdout from everyone.
    # multiple calls to subscribe will add subscriptions, e.g. to subscribe to
    # engine 1's stderr and engine 2's stdout:
    # sub.setsockopt(zmq.SUBSCRIBE, b'engine.1.stderr')
    # sub.setsockopt(zmq.SUBSCRIBE, b'engine.2.stdout')
    sub.connect(iopub_url)
    while True:
        try:
            idents, msg = session.recv(sub, mode=0)
        except KeyboardInterrupt:
            # Ctrl-C is the normal way to stop watching.
            return
        # ident always length 1 here
        topic = idents[0]
        if msg['msg_type'] == 'stream':
            # stdout/stderr
            # stream names are in msg['content']['name'], if you want to handle
            # them differently
            print("%s: %s" % (topic, msg['content']['text']))
        elif msg['msg_type'] == 'pyerr':
            # Python traceback
            c = msg['content']
            # Format instead of concatenating: the ident frame is presumably
            # bytes on Python 3 (TODO confirm), where `topic + ':'` would
            # raise TypeError.  This matches the 'stream' branch above.
            print("%s:" % (topic,))
            for line in c['traceback']:
                # indent lines
                print('    ' + line)
コード例 #10
0
    def open(self, kernel_id):
        """Set up the handler for a newly opened connection to `kernel_id`.

        Rejects the connection with HTTP 404 when ``same_origin()`` is
        false.  Otherwise stores the kernel id, creates a Session from the
        handler's config, and routes the next incoming message to
        ``on_first_message`` while keeping the regular handler in
        ``save_on_message``.
        """
        # Check to see that origin matches host directly, including ports
        if not self.same_origin():
            # `warning` replaces the deprecated `warn` alias (assumes
            # self.log is a standard logging.Logger).
            self.log.warning("Cross Origin WebSocket Attempt.")
            raise web.HTTPError(404)

        self.kernel_id = cast_unicode(kernel_id, 'ascii')
        self.session = Session(config=self.config)
        self.save_on_message = self.on_message
        self.on_message = self.on_first_message
コード例 #11
0
    def open(self, kernel_id):
        """Set up the handler for a newly opened connection to `kernel_id`.

        Stores the kernel id, verifies the request origin on Tornado < 4
        (raising HTTP 403 on failure), creates a Session from the handler's
        config, and routes the next incoming message to
        ``on_first_message`` while keeping the regular handler in
        ``save_on_message``.
        """
        self.kernel_id = cast_unicode(kernel_id, 'ascii')
        # Check to see that origin matches host directly, including ports.
        # Tornado 4 already does CORS checking, so only older versions need it.
        legacy_tornado = tornado.version_info[0] < 4
        if legacy_tornado and not self.check_origin(self.get_origin()):
            raise web.HTTPError(403)

        config = self.config
        self.session = Session(config=config)
        self.save_on_message = self.on_message
        self.on_message = self.on_first_message
コード例 #12
0
 def open(self, kernel_id):
     """Set up the handler for a newly opened connection to `kernel_id`.

     Decodes the kernel id as ASCII, creates a Session from the
     application's config (or from None when that attribute lookup
     fails), and routes the next incoming message to ``on_first_message``
     while keeping the regular handler in ``save_on_message``.
     """
     self.kernel_id = kernel_id.decode('ascii')
     # Start from "no config": this protects the case where the handler is
     # run from something other than the notebook app.
     cfg = None
     try:
         cfg = self.application.config
     except AttributeError:
         pass
     self.session = Session(config=cfg)
     self.save_on_message = self.on_message
     self.on_message = self.on_first_message
コード例 #13
0
 def _setup_session(self, reply, comp_id, timeout=None):
     """
     Record the kernel described by the untrusted reply message `reply` from computer `comp_id`.

     Registers the kernel under its id in ``self._kernels``, marks it in the
     computer's "kernels" mapping, and creates a Session keyed with the
     kernel connection's "key".  When `timeout` is None the kernel's
     timeout is set to the current time plus ``self.kernel_timeout``.
     """
     content = reply["content"]
     kernel_id = content["kernel_id"]
     connection = content["connection"]
     if timeout is None:
         timeout = time.time() + self.kernel_timeout
     kernel_info = {
         "comp_id": comp_id,
         "connection": connection,
         # number of active execute_requests
         "executing": 0,
         "timeout": timeout,
     }
     self._kernels[kernel_id] = kernel_info
     self._comps[comp_id]["kernels"][kernel_id] = None
     self._sessions[kernel_id] = Session(key=connection["key"])
コード例 #14
0
ファイル: nbbot.py プロジェクト: edwardjkim/stress-proxy
def open_run_save(api, path):
    """Open the notebook at `path`, run it once, and save it.

    Only the original notebook is saved, the output is not recorded.

    This is a generator: it yields the results of ``api.new_kernel`` and
    ``run_notebook``.  The kernel is killed whether or not the run
    succeeds; the save only happens after a successful run.
    """
    notebook = api.get_notebook(path)
    nb_session = Session()
    kernel = yield api.new_kernel(nb_session.session)
    kernel_id = kernel['id']
    try:
        yield run_notebook(notebook, kernel, nb_session)
    finally:
        api.kill_kernel(kernel_id)
    gen_log.info("Saving %s/notebooks/%s", api.url, path)
    api.save_notebook(notebook, path)
コード例 #15
0
    def open(self, kernel_id):
        """Set up the handler for a newly opened connection to `kernel_id`.

        Stores the kernel id, verifies the request origin on Tornado < 4
        (logging a warning and raising HTTP 403 on failure), creates a
        Session from the handler's config, routes the next incoming message
        to ``on_first_message`` while keeping the regular handler in
        ``save_on_message``, and starts a periodic callback that invokes
        ``send_ping`` every WS_PING_INTERVAL.
        """
        self.kernel_id = cast_unicode(kernel_id, 'ascii')
        # Check to see that origin matches host directly, including ports
        # Tornado 4 already does CORS checking
        if tornado.version_info[0] < 4:
            if not self.check_origin(self.get_origin()):
                # `warning` replaces the deprecated `warn` alias (assumes
                # self.log is a standard logging.Logger).
                self.log.warning("Cross Origin WebSocket Attempt from %s",
                                 self.get_origin())
                raise web.HTTPError(403)

        self.session = Session(config=self.config)
        self.save_on_message = self.on_message
        self.on_message = self.on_first_message
        self.ping_callback = ioloop.PeriodicCallback(self.send_ping,
                                                     WS_PING_INTERVAL)
        self.ping_callback.start()
コード例 #16
0
ファイル: zmqhandlers.py プロジェクト: sunshineca/ipython
    def open(self, kernel_id):
        """Set up the handler for a newly opened connection to `kernel_id`.

        Stores the kernel id, verifies the request origin on Tornado < 4
        (logging a warning and raising HTTP 403 on failure), creates a
        Session from the handler's config, and routes the next incoming
        message to ``on_first_message`` while keeping the regular handler in
        ``save_on_message``.  When ``self.ping_interval`` is positive, it
        records the current IOLoop time as last ping and last pong, then
        starts a periodic callback that invokes ``send_ping``.
        """
        self.kernel_id = cast_unicode(kernel_id, 'ascii')
        # Check to see that origin matches host directly, including ports
        # Tornado 4 already does CORS checking
        if tornado.version_info[0] < 4:
            if not self.check_origin(self.get_origin()):
                # `warning` replaces the deprecated `warn` alias (assumes
                # self.log is a standard logging.Logger).
                self.log.warning("Cross Origin WebSocket Attempt from %s", self.get_origin())
                raise web.HTTPError(403)

        self.session = Session(config=self.config)
        self.save_on_message = self.on_message
        self.on_message = self.on_first_message

        # start the pinging
        if self.ping_interval > 0:
            self.last_ping = ioloop.IOLoop.instance().time()  # Remember time of last ping
            self.last_pong = self.last_ping
            self.ping_callback = ioloop.PeriodicCallback(self.send_ping, self.ping_interval)
            self.ping_callback.start()
コード例 #17
0
ファイル: handlers.py プロジェクト: maximsch2/ipython
 def open(self, kernel_id):
     """Set up the handler for a newly opened connection to `kernel_id`.

     Decodes the kernel id as ASCII, creates a Session from the handler's
     config, and routes the next incoming message to ``on_first_message``
     while keeping the regular handler in ``save_on_message``.
     """
     decoded_id = kernel_id.decode('ascii')
     self.kernel_id = decoded_id
     config = self.config
     self.session = Session(config=config)
     # Remember the regular handler, then divert the first message.
     self.save_on_message = self.on_message
     self.on_message = self.on_first_message
コード例 #18
0
ファイル: test_db.py プロジェクト: 3kwa/ipython
 def setUp(self):
     """Create a session and a database, then load 16 records into it."""
     session = Session()
     self.session = session
     database = self.create_db()
     self.db = database
     # load_records() reads self.session and self.db, so both are set first.
     self.load_records(16)
コード例 #19
0
ファイル: test_db.py プロジェクト: 3kwa/ipython
class TaskDBTest:
    """Test cases shared by task-database backends.

    ``create_db()`` must be overridden to return the database under test.
    The assertion helpers used below (``assertEqual`` etc.) are not defined
    here; presumably concrete classes also inherit from
    ``unittest.TestCase`` (TODO confirm).
    """

    def setUp(self):
        """Create a session and a database, then load 16 records into it."""
        self.session = Session()
        self.db = self.create_db()
        self.load_records(16)

    def create_db(self):
        """Return the database under test; must be overridden."""
        raise NotImplementedError

    def load_records(self, n=1, buffer_size=100):
        """load n records for testing"""
        # sleep 1/10 s, to ensure timestamp is different to previous calls
        time.sleep(0.1)
        loaded_ids = []
        for _ in range(n):
            message = self.session.msg('apply_request', content={'a': 5})
            message['buffers'] = [os.urandom(buffer_size)]
            record = init_record(message)
            record_id = message['header']['msg_id']
            loaded_ids.append(record_id)
            self.db.add_record(record_id, record)
        return loaded_ids

    def test_add_record(self):
        """Adding five records appends five entries to the history."""
        history_before = self.db.get_history()
        self.load_records(5)
        history_after = self.db.get_history()
        self.assertEqual(len(history_after), len(history_before) + 5)
        self.assertEqual(history_after[:-5], history_before)

    def test_drop_record(self):
        """Fetching a dropped record raises KeyError."""
        msg_id = self.load_records()[-1]
        # Fetch once before dropping; the result itself is not needed.
        self.db.get_record(msg_id)
        self.db.drop_record(msg_id)
        self.assertRaises(KeyError, self.db.get_record, msg_id)

    def _round_to_millisecond(self, dt):
        """necessary because mongodb rounds microseconds"""
        # Subtract the sub-millisecond remainder of the microsecond field.
        return dt - timedelta(microseconds=dt.microsecond % 1000)

    def test_update_record(self):
        """update_record() stores the new fields and leaves the rest unchanged."""
        completed = self._round_to_millisecond(datetime.now())
        msg_id = self.db.get_history()[-1]
        expected = self.db.get_record(msg_id)
        changes = {'stdout': 'hello there', 'completed': completed}
        self.db.update_record(msg_id, changes)
        updated = self.db.get_record(msg_id)
        self.assertEqual(updated['stdout'], 'hello there')
        self.assertEqual(updated['completed'], completed)
        expected.update(changes)
        self.assertEqual(expected, updated)

    # def test_update_record_bad(self):
    #     """test updating nonexistent records"""
    #     msg_id = str(uuid.uuid4())
    #     data = {'stdout': 'hello there'}
    #     self.assertRaises(KeyError, self.db.update_record, msg_id, data)

    def test_find_records_dt(self):
        """test finding records by date"""
        history = self.db.get_history()
        pivot = self.db.get_record(history[len(history) // 2])['submitted']
        earlier = self.db.find_records({'submitted': {'$lt': pivot}})
        later = self.db.find_records({'submitted': {'$gte': pivot}})
        self.assertEqual(len(earlier) + len(later), len(history))
        for rec in earlier:
            self.assertTrue(rec['submitted'] < pivot)
        for rec in later:
            self.assertTrue(rec['submitted'] >= pivot)
        for rec in self.db.find_records({'submitted': pivot}):
            self.assertTrue(rec['submitted'] == pivot)

    def test_find_records_keys(self):
        """test extracting subset of record keys"""
        expected_keys = set(['msg_id', 'submitted', 'completed'])
        query = {'msg_id': {'$ne': ''}}
        for rec in self.db.find_records(query, keys=['submitted', 'completed']):
            self.assertEqual(set(rec.keys()), expected_keys)

    def test_find_records_msg_id(self):
        """ensure msg_id is always in found records"""
        # Same query three times, each with a different key selection.
        for keys in (['submitted', 'completed'], ['submitted'], ['msg_id']):
            found = self.db.find_records({'msg_id': {'$ne': ''}}, keys=keys)
            for rec in found:
                self.assertTrue('msg_id' in rec.keys())

    def test_find_records_in(self):
        """test finding records with '$in','$nin' operators"""
        history = self.db.get_history()
        even_ids = history[::2]
        odd_ids = history[1::2]
        matched = self.db.find_records({'msg_id': {'$in': even_ids}})
        self.assertEqual(set(even_ids), set(r['msg_id'] for r in matched))
        unmatched = self.db.find_records({'msg_id': {'$nin': even_ids}})
        self.assertEqual(set(odd_ids), set(r['msg_id'] for r in unmatched))

    def test_get_history(self):
        """History is ordered by submission time and ends with the newest record."""
        previous = datetime(1984, 1, 1)
        for msg_id in self.db.get_history():
            submitted = self.db.get_record(msg_id)['submitted']
            self.assertTrue(submitted >= previous)
            previous = submitted
        newest_id = self.load_records(1)[-1]
        self.assertEqual(self.db.get_history()[-1], newest_id)

    def test_datetime(self):
        """get/set timestamps with datetime objects"""
        msg_id = self.db.get_history()[-1]
        stored = self.db.get_record(msg_id)
        self.assertTrue(isinstance(stored['submitted'], datetime))
        self.db.update_record(msg_id, {'completed': datetime.now()})
        stored = self.db.get_record(msg_id)
        self.assertTrue(isinstance(stored['completed'], datetime))

    def test_drop_matching(self):
        """drop_matching_records() removes every record matching the query."""
        new_ids = self.load_records(10)
        query = {'msg_id': {'$in': new_ids}}
        self.db.drop_matching_records(query)
        remaining = self.db.find_records(query)
        self.assertEqual(len(remaining), 0)

    def test_null(self):
        """test None comparison queries"""
        self.load_records(10)

        no_id = self.db.find_records({'msg_id': None})
        self.assertEqual(len(no_id), 0)

        with_id = self.db.find_records({'msg_id': {'$ne': None}})
        self.assertTrue(len(with_id) >= 10)

    def test_pop_safe_get(self):
        """editing query results shouldn't affect record [get]"""
        msg_id = self.db.get_history()[-1]
        edited = self.db.get_record(msg_id)
        edited.pop('buffers')
        edited['garbage'] = 'hello'
        edited['header']['msg_id'] = 'fubar'
        fresh = self.db.get_record(msg_id)
        self.assertTrue('buffers' in fresh)
        self.assertFalse('garbage' in fresh)
        self.assertEqual(fresh['header']['msg_id'], msg_id)

    def test_pop_safe_find(self):
        """editing query results shouldn't affect record [find]"""
        msg_id = self.db.get_history()[-1]
        query = {'msg_id': msg_id}
        edited = self.db.find_records(query)[0]
        edited.pop('buffers')
        edited['garbage'] = 'hello'
        edited['header']['msg_id'] = 'fubar'
        fresh = self.db.find_records({'msg_id': msg_id})[0]
        self.assertTrue('buffers' in fresh)
        self.assertFalse('garbage' in fresh)
        self.assertEqual(fresh['header']['msg_id'], msg_id)

    def test_pop_safe_find_keys(self):
        """editing query results shouldn't affect record [find+keys]"""
        msg_id = self.db.get_history()[-1]
        query = {'msg_id': msg_id}
        edited = self.db.find_records(query, keys=['buffers', 'header'])[0]
        edited.pop('buffers')
        edited['garbage'] = 'hello'
        edited['header']['msg_id'] = 'fubar'
        fresh = self.db.find_records({'msg_id': msg_id})[0]
        self.assertTrue('buffers' in fresh)
        self.assertFalse('garbage' in fresh)
        self.assertEqual(fresh['header']['msg_id'], msg_id)
コード例 #20
0
 def _session_default(self):
     """Return a new Session built from this object's config."""
     # Imported here rather than at module level, as in the original.
     from IPython.kernel.zmq.session import Session
     session = Session(config=self.config)
     return session
コード例 #21
0
 def setUp(self):
     """Give each test its own default-constructed Session."""
     session = Session()
     self.session = session
コード例 #22
0
ファイル: manager.py プロジェクト: hack1nt0/CM
 def _session_default(self):
     """Return a new Session parented to this object, with an empty key."""
     # don't sign in-process messages
     unsigned_session = Session(key=b'', parent=self)
     return unsigned_session
コード例 #23
0
 def _session_default(self):
     """Return a new Session parented to this object, with an empty key."""
     # Imported here rather than at module level, as in the original.
     from IPython.kernel.zmq.session import Session
     session = Session(parent=self, key=b'')
     return session
コード例 #24
0
ファイル: test_db.py プロジェクト: pykomke/Kurz_Python_KE
 def setUp(self):
     """Prepare a session and a database pre-loaded with 16 records."""
     self.session = Session()
     database = self.create_db()
     self.db = database
     # load_records() reads self.session and self.db, so both are set first.
     initial_count = 16
     self.load_records(initial_count)
コード例 #25
0
    # https://github.com/ipython/ipython/issues/680
    assert isinstance(threading.currentThread(), threading._MainThread)
    try:
        connection_file = "kernel-%s.json" % os.getpid()

        def cleanup_connection_file():
            try:
                os.remove(connection_file)
            except (IOError, OSError):
                pass

        atexit.register(cleanup_connection_file)

        logger = logging.Logger("IPython")
        logger.addHandler(logging.NullHandler())
        session = Session(username=u'kernel')

        context = zmq.Context.instance()
        ip = socket.gethostbyname(socket.gethostname())
        transport = "tcp"
        addr = "%s://%s" % (transport, ip)
        shell_socket = context.socket(zmq.ROUTER)
        shell_port = shell_socket.bind_to_random_port(addr)
        iopub_socket = context.socket(zmq.PUB)
        iopub_port = iopub_socket.bind_to_random_port(addr)
        control_socket = context.socket(zmq.ROUTER)
        control_port = control_socket.bind_to_random_port(addr)

        hb_ctx = zmq.Context()
        heartbeat = Heartbeat(hb_ctx, (transport, ip, 0))
        hb_port = heartbeat.port
コード例 #26
0
ファイル: Debug.py プロジェクト: tazdriver/returnn
def initIPythonKernel():
    """Start an IPython ZMQ kernel in a daemon thread of this process.

    Binds shell, iopub and control sockets to random TCP ports on this
    host's address, starts a heartbeat, and writes the connection info to
    the relative path ``ipython-kernel-<ip>-<pid>.json``.  Removal of that
    file is registered with ``atexit``.

    If the IPython/zmq imports fail, or anything raises while setting the
    kernel up, a message is printed and the function returns without
    starting the kernel thread.  Must be called from the main thread (this
    is asserted).
    """
    # You can remotely connect to this kernel through the connection file.
    # NOTE(review): the hint that used to be printed to stdout is commented
    # out below, so the connection file name is currently not reported.
    try:
        import IPython.kernel.zmq.ipkernel
        from IPython.kernel.zmq.ipkernel import Kernel
        from IPython.kernel.zmq.heartbeat import Heartbeat
        from IPython.kernel.zmq.session import Session
        from IPython.kernel import write_connection_file
        import zmq
        from zmq.eventloop import ioloop
        from zmq.eventloop.zmqstream import ZMQStream
        IPython.kernel.zmq.ipkernel.signal = lambda sig, f: None  # Overwrite.
    except ImportError as e:
        print("IPython import error, cannot start IPython kernel. %s" % e)
        return
    import atexit
    import socket
    import logging
    import threading

    # Do in mainthread to avoid history sqlite DB errors at exit.
    # https://github.com/ipython/ipython/issues/680
    # current_thread() replaces the deprecated camelCase alias currentThread().
    assert isinstance(threading.current_thread(), threading._MainThread)
    try:
        ip = socket.gethostbyname(socket.gethostname())
        connection_file = "ipython-kernel-%s-%s.json" % (ip, os.getpid())

        def cleanup_connection_file():
            # Best effort: the file may already be gone at exit.
            try:
                os.remove(connection_file)
            except (IOError, OSError):
                pass

        atexit.register(cleanup_connection_file)

        # Dedicated logger with a null handler, passed to the kernel below.
        logger = logging.Logger("IPython")
        logger.addHandler(logging.NullHandler())
        session = Session(username=u'kernel')

        context = zmq.Context.instance()
        transport = "tcp"
        addr = "%s://%s" % (transport, ip)
        shell_socket = context.socket(zmq.ROUTER)
        shell_port = shell_socket.bind_to_random_port(addr)
        iopub_socket = context.socket(zmq.PUB)
        iopub_port = iopub_socket.bind_to_random_port(addr)
        control_socket = context.socket(zmq.ROUTER)
        control_port = control_socket.bind_to_random_port(addr)

        # The heartbeat gets its own zmq context; its port is read back
        # from the Heartbeat object after construction with port 0.
        hb_ctx = zmq.Context()
        heartbeat = Heartbeat(hb_ctx, (transport, ip, 0))
        hb_port = heartbeat.port
        heartbeat.start()

        shell_stream = ZMQStream(shell_socket)
        control_stream = ZMQStream(control_socket)

        kernel = Kernel(session=session,
                        shell_streams=[shell_stream, control_stream],
                        iopub_socket=iopub_socket,
                        log=logger)

        write_connection_file(connection_file,
                              shell_port=shell_port,
                              iopub_port=iopub_port,
                              control_port=control_port,
                              hb_port=hb_port,
                              ip=ip)

        #print "To connect another client to this IPython kernel, use:", \
        #      "ipython console --existing %s" % connection_file
    except Exception as e:
        print("Exception while initializing IPython ZMQ kernel. %s" % e)
        return

    def ipython_thread():
        # Start the kernel, then run the IOLoop until interrupted.
        kernel.start()
        try:
            ioloop.IOLoop.instance().start()
        except KeyboardInterrupt:
            pass

    thread = threading.Thread(target=ipython_thread, name="IPython kernel")
    thread.daemon = True
    thread.start()
コード例 #27
0
 def open(self, kernel_id):
     """Set up the handler for a newly opened connection to `kernel_id`.

     Stores the kernel id (converted with ``cast_unicode(..., 'ascii')``),
     creates a Session from the handler's config, and routes the next
     incoming message to ``on_first_message`` while keeping the regular
     handler in ``save_on_message``.
     """
     unicode_id = cast_unicode(kernel_id, 'ascii')
     self.kernel_id = unicode_id
     config = self.config
     self.session = Session(config=config)
     # Remember the regular handler, then divert the first message.
     self.save_on_message = self.on_message
     self.on_message = self.on_first_message
コード例 #28
0
ファイル: test_serialize.py プロジェクト: wguo123/ipython
def test_serialize_binary():
    """serialize_binary_message() returns bytes for a message with buffers."""
    session = Session()
    message = session.msg('data_pub', content={'a': 'b'})
    # Three random 3-byte buffers, wrapped as memoryviews.
    message['buffers'] = [memoryview(os.urandom(3)) for _ in range(3)]
    serialized = serialize_binary_message(message)
    nt.assert_is_instance(serialized, bytes)
コード例 #29
0
ファイル: zmqhandlers.py プロジェクト: otakucode/ipython
 def initialize(self):
     """Log the request path at debug level and create the handler's Session."""
     request_path = self.request.path
     self.log.debug("Initializing websocket connection %s", request_path)
     config = self.config
     self.session = Session(config=config)
コード例 #30
0
ファイル: test_serialize.py プロジェクト: 2t7/ipython
def test_serialize_binary():
    """A message carrying buffers serializes to a bytes object."""
    session = Session()
    message = session.msg('data_pub', content={'a': 'b'})
    # Three random 3-byte buffers.
    message['buffers'] = [os.urandom(3) for _ in range(3)]
    wire = serialize_binary_message(message)
    nt.assert_is_instance(wire, bytes)
コード例 #31
0
 def init_session(self):
     """create our session object"""
     # default_secure() is given the config before the Session is built from it.
     default_secure(self.config)
     kernel_session = Session(parent=self, username=u'kernel')
     self.session = kernel_session
コード例 #32
0
ファイル: test_db.py プロジェクト: pykomke/Kurz_Python_KE
class TaskDBTest:
    """Test cases shared by task-database backends.

    ``create_db()`` must be overridden to return the database under test.
    The assertion helpers used below (``assertEqual`` etc.) are not defined
    here; presumably concrete classes also inherit from
    ``unittest.TestCase`` (TODO confirm).
    """

    def setUp(self):
        """Create a session and a database, then load 16 records into it."""
        self.session = Session()
        self.db = self.create_db()
        self.load_records(16)

    def create_db(self):
        """Return the database under test; must be overridden."""
        raise NotImplementedError

    def load_records(self, n=1, buffer_size=100):
        """load n records for testing"""
        # sleep 1/10 s, to ensure timestamp is different to previous calls
        time.sleep(0.1)
        created = []
        for _ in range(n):
            request = self.session.msg('apply_request', content={'a': 5})
            request['buffers'] = [os.urandom(buffer_size)]
            record = init_record(request)
            request_id = request['header']['msg_id']
            created.append(request_id)
            self.db.add_record(request_id, record)
        return created

    def test_add_record(self):
        """Adding five records appends five entries to the history."""
        old_history = self.db.get_history()
        self.load_records(5)
        new_history = self.db.get_history()
        self.assertEqual(len(new_history), len(old_history) + 5)
        self.assertEqual(new_history[:-5], old_history)

    def test_drop_record(self):
        """Fetching a dropped record raises KeyError."""
        msg_id = self.load_records()[-1]
        # Fetch once before dropping; the result itself is not needed.
        self.db.get_record(msg_id)
        self.db.drop_record(msg_id)
        self.assertRaises(KeyError, self.db.get_record, msg_id)

    def _round_to_millisecond(self, dt):
        """necessary because mongodb rounds microseconds"""
        # Subtract the sub-millisecond remainder of the microsecond field.
        sub_millisecond = dt.microsecond % 1000
        return dt - timedelta(microseconds=sub_millisecond)

    def test_update_record(self):
        """update_record() stores the new fields and leaves the rest unchanged."""
        stamp = self._round_to_millisecond(datetime.now())
        msg_id = self.db.get_history()[-1]
        original = self.db.get_record(msg_id)
        update = {'stdout': 'hello there', 'completed': stamp}
        self.db.update_record(msg_id, update)
        stored = self.db.get_record(msg_id)
        self.assertEqual(stored['stdout'], 'hello there')
        self.assertEqual(stored['completed'], stamp)
        original.update(update)
        self.assertEqual(original, stored)

    # def test_update_record_bad(self):
    #     """test updating nonexistent records"""
    #     msg_id = str(uuid.uuid4())
    #     data = {'stdout': 'hello there'}
    #     self.assertRaises(KeyError, self.db.update_record, msg_id, data)

    def test_find_records_dt(self):
        """test finding records by date"""
        history = self.db.get_history()
        midpoint = self.db.get_record(history[len(history) // 2])
        cutoff = midpoint['submitted']
        older = self.db.find_records({'submitted': {'$lt': cutoff}})
        newer = self.db.find_records({'submitted': {'$gte': cutoff}})
        self.assertEqual(len(older) + len(newer), len(history))
        for rec in older:
            self.assertTrue(rec['submitted'] < cutoff)
        for rec in newer:
            self.assertTrue(rec['submitted'] >= cutoff)
        for rec in self.db.find_records({'submitted': cutoff}):
            self.assertTrue(rec['submitted'] == cutoff)

    def test_find_records_keys(self):
        """test extracting subset of record keys"""
        wanted = set(['msg_id', 'submitted', 'completed'])
        query = {'msg_id': {'$ne': ''}}
        found = self.db.find_records(query, keys=['submitted', 'completed'])
        for rec in found:
            self.assertEqual(set(rec.keys()), wanted)

    def test_find_records_msg_id(self):
        """ensure msg_id is always in found records"""
        # Same query three times, each with a different key selection.
        key_selections = (['submitted', 'completed'], ['submitted'], ['msg_id'])
        for keys in key_selections:
            found = self.db.find_records({'msg_id': {'$ne': ''}}, keys=keys)
            for rec in found:
                self.assertTrue('msg_id' in rec.keys())

    def test_find_records_in(self):
        """test finding records with '$in','$nin' operators"""
        history = self.db.get_history()
        even_ids = history[::2]
        odd_ids = history[1::2]
        included = self.db.find_records({'msg_id': {'$in': even_ids}})
        self.assertEqual(set(even_ids), set(r['msg_id'] for r in included))
        excluded = self.db.find_records({'msg_id': {'$nin': even_ids}})
        self.assertEqual(set(odd_ids), set(r['msg_id'] for r in excluded))

    def test_get_history(self):
        """History is ordered by submission time and ends with the newest record."""
        floor = datetime(1984, 1, 1)
        for msg_id in self.db.get_history():
            submitted = self.db.get_record(msg_id)['submitted']
            self.assertTrue(submitted >= floor)
            floor = submitted
        latest_id = self.load_records(1)[-1]
        self.assertEqual(self.db.get_history()[-1], latest_id)

    def test_datetime(self):
        """get/set timestamps with datetime objects"""
        msg_id = self.db.get_history()[-1]
        record = self.db.get_record(msg_id)
        self.assertTrue(isinstance(record['submitted'], datetime))
        self.db.update_record(msg_id, {'completed': datetime.now()})
        record = self.db.get_record(msg_id)
        self.assertTrue(isinstance(record['completed'], datetime))

    def test_drop_matching(self):
        """drop_matching_records() removes every record matching the query."""
        fresh_ids = self.load_records(10)
        query = {'msg_id': {'$in': fresh_ids}}
        self.db.drop_matching_records(query)
        leftover = self.db.find_records(query)
        self.assertEqual(len(leftover), 0)

    def test_null(self):
        """test None comparison queries"""
        self.load_records(10)

        null_matches = self.db.find_records({'msg_id': None})
        self.assertEqual(len(null_matches), 0)

        non_null_matches = self.db.find_records({'msg_id': {'$ne': None}})
        self.assertTrue(len(non_null_matches) >= 10)

    def test_pop_safe_get(self):
        """editing query results shouldn't affect record [get]"""
        msg_id = self.db.get_history()[-1]
        scratch = self.db.get_record(msg_id)
        scratch.pop('buffers')
        scratch['garbage'] = 'hello'
        scratch['header']['msg_id'] = 'fubar'
        reread = self.db.get_record(msg_id)
        self.assertTrue('buffers' in reread)
        self.assertFalse('garbage' in reread)
        self.assertEqual(reread['header']['msg_id'], msg_id)

    def test_pop_safe_find(self):
        """editing query results shouldn't affect record [find]"""
        msg_id = self.db.get_history()[-1]
        scratch = self.db.find_records({'msg_id': msg_id})[0]
        scratch.pop('buffers')
        scratch['garbage'] = 'hello'
        scratch['header']['msg_id'] = 'fubar'
        reread = self.db.find_records({'msg_id': msg_id})[0]
        self.assertTrue('buffers' in reread)
        self.assertFalse('garbage' in reread)
        self.assertEqual(reread['header']['msg_id'], msg_id)

    def test_pop_safe_find_keys(self):
        """editing query results shouldn't affect record [find+keys]"""
        msg_id = self.db.get_history()[-1]
        selected = ['buffers', 'header']
        scratch = self.db.find_records({'msg_id': msg_id}, keys=selected)[0]
        scratch.pop('buffers')
        scratch['garbage'] = 'hello'
        scratch['header']['msg_id'] = 'fubar'
        reread = self.db.find_records({'msg_id': msg_id})[0]
        self.assertTrue('buffers' in reread)
        self.assertFalse('garbage' in reread)
        self.assertEqual(reread['header']['msg_id'], msg_id)
コード例 #33
0
 def initialize(self):
     """Create the handler's Session from its config."""
     config = self.config
     self.session = Session(config=config)