Ejemplo n.º 1
0
 def test_pyzmq_version_info(self):
     """pyzmq_version_info() parses __version__ into a comparable 3-tuple.

     A '2.10dev' string maps to (2, 10, inf), so it compares greater than
     (2, 10, 99). A plain '2.1.10' maps to (2, 1, 10). The patched
     zmq.sugar.version.__version__ is always restored afterwards.
     """
     version = zmq.sugar.version
     save = version.__version__
     try:
         version.__version__ = '2.10dev'
         info = zmq.pyzmq_version_info()
         # assertEqual/assertIsInstance/assertGreater replace assertEquals,
         # a deprecated alias removed in Python 3.12.
         self.assertIsInstance(info, tuple)
         self.assertEqual(len(info), 3)
         self.assertGreater(info, (2, 10, 99))
         self.assertEqual(info, (2, 10, float('inf')))
         version.__version__ = '2.1.10'
         info = zmq.pyzmq_version_info()
         self.assertEqual(info, (2, 1, 10))
         self.assertGreater(info, (2, 1, 9))
     finally:
         version.__version__ = save
Ejemplo n.º 2
0
    def enable_monitor(self, events=None):
        """Attach a socket monitor that feeds ZMQ events to the protocol.

        This is a generator-based coroutine (it uses ``yield from``). If a
        monitor is already attached, it returns without doing anything.

        :param events: bitmask of ``zmq.EVENT_*`` flags. Any falsy value,
            including the default ``None``, selects ``zmq.EVENT_ALL``.
        :raises NotImplementedError: when libzmq < 4 or pyzmq < 14.4.
        """
        # Binding first and then connecting does not work here: the event
        # loop does not reliably detect messages on the inproc transport in
        # that order, so event messages get missed. pyzmq's
        # get_monitor_socket() uses that same unsuitable order, so it cannot
        # be used with an event loop either.
        # Background:
        # http://lists.zeromq.org/pipermail/zeromq-dev/2015-July/029181.html

        supported = (zmq.zmq_version_info() >= (4,) and
                     zmq.pyzmq_version_info() >= (14, 4))
        if not supported:
            raise NotImplementedError(
                "Socket monitor requires libzmq >= 4 and pyzmq >= 14.4, "
                "have libzmq:{}, pyzmq:{}".format(
                    zmq.zmq_version(), zmq.pyzmq_version()))

        if self._monitor is not None:
            return

        monitor_addr = "inproc://monitor.s-{}".format(self._zmq_sock.FD)
        event_mask = events if events else zmq.EVENT_ALL
        _, self._monitor = yield from create_zmq_connection(
            lambda: _ZmqEventProtocol(self._loop, self._protocol),
            zmq.PAIR, connect=monitor_addr, loop=self._loop)
        # The monitor() call binds monitor_addr, so it must only run after
        # the PAIR socket above has connected.
        self._zmq_sock.monitor(monitor_addr, event_mask)
        yield from self._monitor.wait_ready
Ejemplo n.º 3
0
 def test_pyzmq_version_info(self):
     """pyzmq_version_info() parses __version__ into a comparable 3-tuple.

     A '2.10dev' string maps to (2, 10, inf), so it compares greater than
     (2, 10, 99). A plain '2.1.10' maps to (2, 1, 10). The patched
     zmq.core.version.__version__ is always restored afterwards.
     """
     version = zmq.core.version
     save = version.__version__
     try:
         version.__version__ = '2.10dev'
         info = zmq.pyzmq_version_info()
         # assertEqual/assertIsInstance/assertGreater replace assertEquals,
         # a deprecated alias removed in Python 3.12.
         self.assertIsInstance(info, tuple)
         self.assertEqual(len(info), 3)
         self.assertGreater(info, (2, 10, 99))
         self.assertEqual(info, (2, 10, float('inf')))
         version.__version__ = '2.1.10'
         info = zmq.pyzmq_version_info()
         self.assertEqual(info, (2, 1, 10))
         self.assertGreater(info, (2, 1, 9))
     finally:
         version.__version__ = save
Ejemplo n.º 4
0
    def enable_monitor(self, events=None):
        """Attach a socket monitor that feeds ZMQ events to the protocol.

        Generator-based coroutine (uses ``yield from``). Calling it again
        while a monitor is attached is a no-op.

        :param events: bitmask of ``zmq.EVENT_*`` flags. A falsy value,
            including the default ``None``, selects ``zmq.EVENT_ALL``.
        :raises NotImplementedError: when libzmq < 4 or pyzmq < 14.4.
        """
        # Binding first and then connecting does not work here: the event
        # loop does not reliably detect messages on the inproc transport in
        # that order, so event messages get missed. pyzmq's
        # get_monitor_socket() uses that same order, so it is unsuitable for
        # use with an event loop.
        # Background:
        # http://lists.zeromq.org/pipermail/zeromq-dev/2015-July/029181.html

        libzmq_and_pyzmq_ok = (zmq.zmq_version_info() >= (4,) and
                               zmq.pyzmq_version_info() >= (14, 4))
        if not libzmq_and_pyzmq_ok:
            raise NotImplementedError(
                "Socket monitor requires libzmq >= 4 and pyzmq >= 14.4, "
                "have libzmq:{}, pyzmq:{}".format(zmq.zmq_version(),
                                                  zmq.pyzmq_version()))

        if self._monitor is not None:
            return

        endpoint = "inproc://monitor.s-{}".format(self._zmq_sock.FD)
        wanted = events if events else zmq.EVENT_ALL
        _, self._monitor = yield from create_zmq_connection(
            lambda: _ZmqEventProtocol(self._loop, self._protocol),
            zmq.PAIR,
            connect=endpoint,
            loop=self._loop)
        # The monitor() call binds the endpoint; it has to follow the connect.
        self._zmq_sock.monitor(endpoint, wanted)
        yield from self._monitor.wait_ready
Ejemplo n.º 5
0
def check_for_pyzmq():
    """Check for pyzmq >= 2.1.4 and report the outcome via print_status.

    Returns True when a new-enough pyzmq can be imported, False otherwise.
    Every path emits exactly one ``print_status('pyzmq', ...)`` line.
    """
    try:
        import zmq
    except ImportError:
        print_status(
            'pyzmq',
            "no (required for qtconsole, notebook, and parallel computing capabilities)"
        )
        return False
    else:
        # pyzmq 2.1.10 adds the pyzmq_version_info function for returning
        # the version as a tuple
        if hasattr(zmq, 'pyzmq_version_info'):
            if zmq.pyzmq_version_info() >= (2, 1, 4):
                print_status("pyzmq", zmq.__version__)
                return True
            else:
                # this branch can never occur, at least until we update our
                # pyzmq dependency beyond 2.1.10; still report the reason
                # like the other failure paths instead of failing silently
                print_status(
                    'pyzmq', "no (have %s, but require >= 2.1.4 for"
                    " qtconsole and parallel computing capabilities)" %
                    zmq.__version__)
                return False
        # this is necessarily earlier than 2.1.10, so string comparison is
        # okay
        if zmq.__version__ < '2.1.4':
            print_status(
                'pyzmq', "no (have %s, but require >= 2.1.4 for"
                " qtconsole and parallel computing capabilities)" %
                zmq.__version__)
            return False
        else:
            print_status("pyzmq", zmq.__version__)
            return True
Ejemplo n.º 6
0
def _zmq_has_curve():
    """
    Return whether the current ZMQ has support for auth and CurveZMQ security.

    :rtype: bool

    Version notes:
      - `zmq.curve_keypair()` is new in pyzmq 14.0 / libzmq 4.0, and needs
        libzmq (>= 4.0) to have been linked with libsodium.
      - The `zmq.auth` module is new in pyzmq 14.1.
      - `zmq.has()` is new in pyzmq 14.1 / libzmq 4.1.
    """
    libzmq_version = zmq.zmq_version_info()
    pyzmq_version = zmq.pyzmq_version_info()

    if pyzmq_version < (14, 1, 0):
        # auth support needs pyzmq 14.1
        return False

    if libzmq_version >= (4, 1):
        # both libraries are new enough to simply ask
        return zmq.has('curve')

    if libzmq_version < (4, 0):
        # security is new in libzmq 4.0
        return False

    # libzmq 4.0.x: probe by generating a keypair; this raises ZMQError
    # when libzmq was not linked against libsodium
    try:
        zmq.curve_keypair()
    except zmq.error.ZMQError:
        return False
    else:
        return True
Ejemplo n.º 7
0
def check_for_pyzmq():
    """Check for pyzmq >= 2.1.4 and report the outcome via print_status.

    Returns True when a new-enough pyzmq can be imported, False otherwise.
    Every path emits exactly one ``print_status('pyzmq', ...)`` line.
    """
    try:
        import zmq
    except ImportError:
        print_status('pyzmq', "no (required for qtconsole, notebook, and parallel computing capabilities)")
        return False
    else:
        # pyzmq 2.1.10 adds the pyzmq_version_info function for returning
        # the version as a tuple
        if hasattr(zmq, 'pyzmq_version_info'):
            if zmq.pyzmq_version_info() >= (2,1,4):
                print_status("pyzmq", zmq.__version__)
                return True
            else:
                # this branch can never occur, at least until we update our
                # pyzmq dependency beyond 2.1.10; still report the reason
                # like the other failure paths instead of failing silently
                print_status('pyzmq', "no (have %s, but require >= 2.1.4 for"
                " qtconsole and parallel computing capabilities)"%zmq.__version__)
                return False
        # this is necessarily earlier than 2.1.10, so string comparison is
        # okay
        if zmq.__version__ < '2.1.4':
            print_status('pyzmq', "no (have %s, but require >= 2.1.4 for"
            " qtconsole and parallel computing capabilities)"%zmq.__version__)
            return False
        else:
            print_status("pyzmq", zmq.__version__)
            return True
Ejemplo n.º 8
0
 def test_pyzmq_version_info(self):
     """pyzmq_version_info() is a tuple of ints, plus an inf field for extras."""
     info = zmq.pyzmq_version_info()
     self.assertTrue(isinstance(info, tuple))
     self.assertTrue(all(isinstance(part, int) for part in info[:3]))
     has_extra = bool(version.VERSION_EXTRA)
     self.assertEqual(len(info), 4 if has_extra else 3)
     if has_extra:
         # a pre-release/extra suffix is encoded as +inf in the 4th slot
         self.assertEqual(info[-1], float('inf'))
Ejemplo n.º 9
0
 def test_pyzmq_version_info(self):
     """pyzmq_version_info() is a tuple of ints, plus an inf field for extras."""
     info = zmq.pyzmq_version_info()
     self.assertTrue(isinstance(info, tuple))
     numeric_ok = all(isinstance(field, int) for field in info[:3])
     self.assertTrue(numeric_ok)
     extra = bool(version.VERSION_EXTRA)
     self.assertEqual(len(info), 4 if extra else 3)
     if extra:
         # a pre-release/extra suffix is encoded as +inf in the 4th slot
         self.assertEqual(info[-1], float('inf'))
Ejemplo n.º 10
0
def check_for_pyzmq():
    """Check for pyzmq >= 2.1.11 and report the outcome via print_status.

    Returns True when a new-enough pyzmq can be imported, False otherwise.
    """
    try:
        import zmq
    except ImportError:
        print_status('pyzmq', "no (required for qtconsole, notebook, and parallel computing capabilities)")
        return False

    # pyzmq_version_info (the version as a tuple) was added in pyzmq 2.1.10,
    # so if it is missing the installed pyzmq is too old.
    new_enough = (hasattr(zmq, 'pyzmq_version_info')
                  and zmq.pyzmq_version_info() >= (2, 1, 11))
    if not new_enough:
        print_status('pyzmq', "no (have %s, but require >= 2.1.11 for"
                     " qtconsole, notebook, and parallel computing capabilities)" % zmq.__version__)
        return False
    print_status("pyzmq", zmq.__version__)
    return True
Ejemplo n.º 11
0
 def init_signal(self):
     """Install this process's SIGINT and SIGTERM handlers.

     SIGINT goes to ``self._handle_sigint`` only when the pyzmq version
     tuple is >= (2, 1, 9). SIGTERM always goes to ``self._signal_stop``.
     """
     # FIXME: remove this check when pyzmq dependency is >= 2.1.11
     # safely extract zmq version info:
     try:
         zmq_v = zmq.pyzmq_version_info()
     except AttributeError:
         # Older pyzmq lacks pyzmq_version_info(): rebuild the tuple from the
         # digit groups of __version__, appending 999 for a 'dev' build.
         numbers = [int(group) for group in re.findall(r'\d+', zmq.__version__)]
         dev_tail = [999] if 'dev' in zmq.__version__ else []
         zmq_v = tuple(numbers + dev_tail)
     sigint_ok = zmq_v >= (2, 1, 9)
     if sigint_ok:
         # This won't work with 2.1.7 and
         # 2.1.9-10 will log ugly 'Interrupted system call' messages,
         # but it will work
         signal.signal(signal.SIGINT, self._handle_sigint)
     signal.signal(signal.SIGTERM, self._signal_stop)
Ejemplo n.º 12
0
 def init_signal(self):
     """Install this process's SIGINT and SIGTERM handlers.

     SIGINT goes to ``self._handle_sigint`` only when the pyzmq version
     tuple is >= (2, 1, 9) and the platform is not Windows. SIGTERM always
     goes to ``self._signal_stop``.
     """
     # FIXME: remove this check when pyzmq dependency is >= 2.1.11
     # safely extract zmq version info:
     try:
         zmq_v = zmq.pyzmq_version_info()
     except AttributeError:
         # Older pyzmq lacks pyzmq_version_info(): rebuild the tuple from the
         # digit groups of __version__, appending 999 for a 'dev' build.
         numbers = [int(group) for group in re.findall(r'\d+', zmq.__version__)]
         dev_tail = [999] if 'dev' in zmq.__version__ else []
         zmq_v = tuple(numbers + dev_tail)
     sigint_ok = zmq_v >= (2, 1, 9) and not sys.platform.startswith('win')
     if sigint_ok:
         # This won't work with 2.1.7 and
         # 2.1.9-10 will log ugly 'Interrupted system call' messages,
         # but it will work
         signal.signal(signal.SIGINT, self._handle_sigint)
     signal.signal(signal.SIGTERM, self._signal_stop)
Ejemplo n.º 13
0
def check_for_pyzmq():
    """Check for pyzmq >= 2.1.11 and report the outcome via print_status.

    Returns True when a new-enough pyzmq can be imported, False otherwise.
    """
    try:
        import zmq
    except ImportError:
        print_status('pyzmq', "no (required for qtconsole, notebook, and parallel computing capabilities)")
        return False

    # pyzmq_version_info (the version as a tuple) was added in pyzmq 2.1.10,
    # so if it is missing the installed pyzmq is too old.
    if not (hasattr(zmq, 'pyzmq_version_info')
            and zmq.pyzmq_version_info() >= (2, 1, 11)):
        print_status('pyzmq', "no (have %s, but require >= 2.1.11 for"
                     " qtconsole, notebook, and parallel computing capabilities)" % zmq.__version__)
        return False
    print_status("pyzmq", zmq.__version__)
    return True
Ejemplo n.º 14
0
 def init_signal(self):
     """Install this process's signal handlers.

     - SIGINT -> ``self._handle_sigint``, only when the pyzmq version tuple
       is >= (2, 1, 9) and the platform is not Windows.
     - SIGTERM -> ``self._signal_stop``, always.
     - SIGUSR1 and SIGINFO -> ``self._signal_info``, each only when the
       ``signal`` module defines it.
     """
     # FIXME: remove this check when pyzmq dependency is >= 2.1.11
     # safely extract zmq version info:
     try:
         zmq_v = zmq.pyzmq_version_info()
     except AttributeError:
         # Older pyzmq lacks pyzmq_version_info(): rebuild the tuple from the
         # digit groups of __version__, appending 999 for a 'dev' build.
         numbers = [int(group) for group in re.findall(r'\d+', zmq.__version__)]
         dev_tail = [999] if 'dev' in zmq.__version__ else []
         zmq_v = tuple(numbers + dev_tail)
     sigint_ok = zmq_v >= (2, 1, 9) and not sys.platform.startswith('win')
     if sigint_ok:
         # This won't work with 2.1.7 and
         # 2.1.9-10 will log ugly 'Interrupted system call' messages,
         # but it will work
         signal.signal(signal.SIGINT, self._handle_sigint)
     signal.signal(signal.SIGTERM, self._signal_stop)
     # Windows doesn't support SIGUSR1; SIGINFO exists only on BSD-based
     # systems. Register whichever of the two this platform provides.
     for signame in ('SIGUSR1', 'SIGINFO'):
         signum = getattr(signal, signame, None)
         if signum is not None:
             signal.signal(signum, self._signal_info)
Ejemplo n.º 15
0
class BaseZmqEventLoopTestsMixin:

    @asyncio.coroutine
    def make_dealer_router(self):
        """Create a connected DEALER/ROUTER pair over loopback TCP.

        The DEALER binds an unused port and the ROUTER connects to it.
        Returns ``(dealer_tr, dealer_pr, router_tr, router_pr)``.
        """
        endpoint = 'tcp://127.0.0.1:{}'.format(find_unused_port())

        dealer_tr, dealer_pr = yield from aiozmq.create_zmq_connection(
            lambda: Protocol(self.loop),
            zmq.DEALER,
            bind=endpoint,
            loop=self.loop)
        self.assertEqual('CONNECTED', dealer_pr.state)
        yield from dealer_pr.connected

        router_tr, router_pr = yield from aiozmq.create_zmq_connection(
            lambda: Protocol(self.loop),
            zmq.ROUTER,
            connect=endpoint,
            loop=self.loop)
        self.assertEqual('CONNECTED', router_pr.state)
        yield from router_pr.connected

        return dealer_tr, dealer_pr, router_tr, router_pr

    @asyncio.coroutine
    def make_pub_sub(self):
        """Create a connected PUB/SUB pair over loopback TCP.

        The PUB binds an unused port and the SUB connects to it.
        Returns ``(pub_tr, pub_pr, sub_tr, sub_pr)``.
        """
        endpoint = 'tcp://127.0.0.1:{}'.format(find_unused_port())

        pub_tr, pub_pr = yield from aiozmq.create_zmq_connection(
            lambda: Protocol(self.loop),
            zmq.PUB,
            bind=endpoint,
            loop=self.loop)
        self.assertEqual('CONNECTED', pub_pr.state)
        yield from pub_pr.connected

        sub_tr, sub_pr = yield from aiozmq.create_zmq_connection(
            lambda: Protocol(self.loop),
            zmq.SUB,
            connect=endpoint,
            loop=self.loop)
        self.assertEqual('CONNECTED', sub_pr.state)
        yield from sub_pr.connected

        return pub_tr, pub_pr, sub_tr, sub_pr

    def test_req_rep(self):
        """A REQ bound on inproc://test and a REP connected to it complete a
        request/answer round trip and then close."""
        run = self.loop.run_until_complete

        @asyncio.coroutine
        def open_req():
            transport, protocol = yield from aiozmq.create_zmq_connection(
                lambda: Protocol(self.loop),
                zmq.REQ,
                bind='inproc://test',
                loop=self.loop)
            self.assertEqual('CONNECTED', protocol.state)
            yield from protocol.connected
            return transport, protocol

        req_tr, req_pr = run(open_req())

        @asyncio.coroutine
        def open_rep():
            transport, protocol = yield from aiozmq.create_zmq_connection(
                lambda: Protocol(self.loop),
                zmq.REP,
                connect='inproc://test',
                loop=self.loop)
            self.assertEqual('CONNECTED', protocol.state)
            yield from protocol.connected
            return transport, protocol

        rep_tr, rep_pr = run(open_rep())

        @asyncio.coroutine
        def exchange():
            req_tr.write([b'request'])
            got = yield from rep_pr.received.get()
            self.assertEqual([b'request'], got)
            rep_tr.write([b'answer'])
            got = yield from req_pr.received.get()
            self.assertEqual([b'answer'], got)

        run(exchange())

        @asyncio.coroutine
        def shutdown():
            req_tr.close()
            rep_tr.close()

            yield from req_pr.closed
            self.assertEqual('CLOSED', req_pr.state)
            yield from rep_pr.closed
            self.assertEqual('CLOSED', rep_pr.state)

        run(shutdown())

    def test_pub_sub(self):
        """A SUB subscribed to b'node_id' receives a matching publication.

        Publishes up to five times, waiting 0.1 s after each, because early
        publications may not reach the subscriber.
        """

        @asyncio.coroutine
        def go():
            pub_tr, pub_pr, sub_tr, sub_pr = yield from self.make_pub_sub()
            sub_tr.setsockopt(zmq.SUBSCRIBE, b'node_id')

            delivered = False
            for _attempt in range(5):
                pub_tr.write([b'node_id', b'publish'])
                try:
                    msg = yield from asyncio.wait_for(sub_pr.received.get(),
                                                      0.1,
                                                      loop=self.loop)
                except asyncio.TimeoutError:
                    continue
                self.assertEqual([b'node_id', b'publish'], msg)
                delivered = True
                break
            if not delivered:
                raise AssertionError("Cannot get message in subscriber")

            pub_tr.close()
            sub_tr.close()
            yield from pub_pr.closed
            self.assertEqual('CLOSED', pub_pr.state)
            yield from sub_pr.closed
            self.assertEqual('CLOSED', sub_pr.state)

        self.loop.run_until_complete(go())

    def test_getsockopt(self):
        """getsockopt(zmq.TYPE) on a DEALER transport returns zmq.DEALER."""
        port = find_unused_port()

        @asyncio.coroutine
        def coro():
            tr, pr = yield from aiozmq.create_zmq_connection(
                lambda: Protocol(self.loop),
                zmq.DEALER,
                bind='tcp://127.0.0.1:{}'.format(port),
                loop=self.loop)
            yield from pr.connected
            try:
                self.assertEqual(zmq.DEALER, tr.getsockopt(zmq.TYPE))
            finally:
                # Release the socket and its bound port whatever the outcome.
                tr.close()
            return tr, pr

        self.loop.run_until_complete(coro())

    def test_dealer_router(self):
        """ROUTER sees the DEALER's identity frame and can reply through it."""
        @asyncio.coroutine
        def go():
            (dealer_tr, dealer_pr,
             router_tr, router_pr) = yield from self.make_dealer_router()
            dealer_tr.write([b'request'])
            incoming = yield from router_pr.received.get()
            self.assertEqual([mock.ANY, b'request'], incoming)
            identity = incoming[0]
            router_tr.write([identity, b'answer'])
            reply = yield from dealer_pr.received.get()
            self.assertEqual([b'answer'], reply)

            dealer_tr.close()
            router_tr.close()

            yield from dealer_pr.closed
            self.assertEqual('CLOSED', dealer_pr.state)
            yield from router_pr.closed
            self.assertEqual('CLOSED', router_pr.state)

        self.loop.run_until_complete(go())

    def test_binds(self):
        """bind()/unbind() keep tr.bindings() in sync with bound endpoints."""
        port1 = find_unused_port()
        port2 = find_unused_port()
        addr1 = 'tcp://127.0.0.1:{}'.format(port1)
        addr2 = 'tcp://127.0.0.1:{}'.format(port2)

        @asyncio.coroutine
        def connect():
            tr, pr = yield from aiozmq.create_zmq_connection(
                lambda: Protocol(self.loop),
                zmq.REQ,
                bind=[addr1, addr2],
                loop=self.loop)
            yield from pr.connected
            try:
                self.assertEqual({addr1, addr2}, tr.bindings())

                # bind() returns the resulting endpoint, which must then
                # show up in bindings()
                addr3 = yield from tr.bind('tcp://127.0.0.1:*')
                self.assertEqual({addr1, addr2, addr3}, tr.bindings())
                yield from tr.unbind(addr2)
                self.assertEqual({addr1, addr3}, tr.bindings())
                self.assertIn(addr1, tr.bindings())
                # Match the repr by pattern rather than by exact string. Dots
                # are escaped and the stray '.' after the first ':' is gone,
                # so only real 'host:port' endpoints match.
                self.assertRegex(
                    repr(tr.bindings()),
                    r'\{tcp://127\.0\.0\.1:\d+, tcp://127\.0\.0\.1:\d+\}')
            finally:
                tr.close()

        self.loop.run_until_complete(connect())

    def test_connects(self):
        """connect()/disconnect() keep tr.connections() in sync."""
        addr1, addr2, addr3 = [
            'tcp://127.0.0.1:{}'.format(find_unused_port()) for _ in range(3)]

        @asyncio.coroutine
        def go():
            transport, protocol = yield from aiozmq.create_zmq_connection(
                lambda: Protocol(self.loop),
                zmq.REQ,
                connect=[addr1, addr2],
                loop=self.loop)
            yield from protocol.connected

            self.assertEqual({addr1, addr2}, transport.connections())
            yield from transport.connect(addr3)
            self.assertEqual({addr1, addr3, addr2}, transport.connections())
            yield from transport.disconnect(addr1)
            self.assertEqual({addr2, addr3}, transport.connections())
            transport.close()

        self.loop.run_until_complete(go())

    def test_zmq_socket(self):
        """A caller-supplied zmq_sock is used as-is by the new transport."""
        sock = zmq.Context.instance().socket(zmq.PUB)

        @asyncio.coroutine
        def connect():
            transport, protocol = yield from aiozmq.create_zmq_connection(
                lambda: Protocol(self.loop),
                zmq.PUB,
                zmq_sock=sock,
                loop=self.loop)
            yield from protocol.connected
            return transport, protocol

        transport, _ = self.loop.run_until_complete(connect())
        self.assertIs(sock, transport._zmq_sock)
        self.assertFalse(sock.closed)
        transport.close()

    def test_zmq_socket_invalid_type(self):
        """Passing a PUB zmq_sock while requesting SUB raises ValueError.

        The failed call must leave the caller-owned socket open.
        """
        zmq_sock = zmq.Context.instance().socket(zmq.PUB)

        @asyncio.coroutine
        def connect():
            tr, pr = yield from aiozmq.create_zmq_connection(
                lambda: Protocol(self.loop),
                zmq.SUB,
                zmq_sock=zmq_sock,
                loop=self.loop)
            yield from pr.connected
            return tr, pr

        try:
            with self.assertRaises(ValueError):
                self.loop.run_until_complete(connect())
            self.assertFalse(zmq_sock.closed)
        finally:
            # The socket is deliberately left open by the failed call, so the
            # test must close it itself to avoid leaking it.
            zmq_sock.close()

    def test_create_zmq_connection_ZMQError(self):
        """An already-closed zmq_sock makes the call raise OSError whose
        errno is ENOTSUP or ENOTSOCK."""
        dead_sock = zmq.Context.instance().socket(zmq.PUB)
        dead_sock.close()

        @asyncio.coroutine
        def connect():
            transport, protocol = yield from aiozmq.create_zmq_connection(
                lambda: Protocol(self.loop),
                zmq.SUB,
                zmq_sock=dead_sock,
                loop=self.loop)
            yield from protocol.connected
            return transport, protocol

        with self.assertRaises(OSError) as cm:
            self.loop.run_until_complete(connect())
        self.assertIn(cm.exception.errno, (zmq.ENOTSUP, zmq.ENOTSOCK))

    def test_create_zmq_connection_invalid_bind(self):
        """bind=2 (not an endpoint) is rejected with ValueError."""

        @asyncio.coroutine
        def connect():
            _transport, _protocol = yield from aiozmq.create_zmq_connection(
                lambda: Protocol(self.loop),
                zmq.SUB,
                bind=2,
                loop=self.loop)

        with self.assertRaises(ValueError):
            self.loop.run_until_complete(connect())

    def test_create_zmq_connection_invalid_connect(self):
        """connect=2 (not an endpoint) is rejected with ValueError."""

        @asyncio.coroutine
        def connect():
            _transport, _protocol = yield from aiozmq.create_zmq_connection(
                lambda: Protocol(self.loop),
                zmq.SUB,
                connect=2,
                loop=self.loop)

        with self.assertRaises(ValueError):
            self.loop.run_until_complete(connect())

    @unittest.skipIf(sys.platform == 'win32',
                     "Windows calls abort() on bad socket")
    def test_create_zmq_connection_closes_socket_on_bad_bind(self):
        """An unparsable bind address ('badaddr') raises OSError."""

        @asyncio.coroutine
        def connect():
            transport, protocol = yield from aiozmq.create_zmq_connection(
                lambda: Protocol(self.loop),
                zmq.SUB,
                bind='badaddr',
                loop=self.loop)
            yield from protocol.connected
            return transport, protocol

        with self.assertRaises(OSError):
            self.loop.run_until_complete(connect())

    @unittest.skipIf(sys.platform == 'win32',
                     "Windows calls abort() on bad socket")
    def test_create_zmq_connection_closes_socket_on_bad_connect(self):
        """An unparsable connect address ('badaddr') raises OSError."""

        @asyncio.coroutine
        def attempt():
            # The assertion lives inside the coroutine here, unlike the
            # bad-bind test, which asserts around run_until_complete.
            with self.assertRaises(OSError):
                yield from aiozmq.create_zmq_connection(
                    lambda: Protocol(self.loop),
                    zmq.SUB,
                    connect='badaddr',
                    loop=self.loop)

        self.loop.run_until_complete(attempt())

    def test_create_zmq_connection_dns_in_connect(self):
        """A hostname ('localhost') endpoint is accepted and recorded
        verbatim in connections()."""
        port = find_unused_port()

        @asyncio.coroutine
        def connect():
            endpoint = 'tcp://localhost:{}'.format(port)
            transport, protocol = yield from aiozmq.create_zmq_connection(
                lambda: Protocol(self.loop),
                zmq.SUB,
                connect=endpoint,
                loop=self.loop)
            yield from protocol.connected

            self.assertEqual({endpoint}, transport.connections())
            transport.close()

        self.loop.run_until_complete(connect())

    def test_getsockopt_badopt(self):
        """getsockopt() with an unknown option raises OSError(EINVAL)."""
        port = find_unused_port()

        @asyncio.coroutine
        def connect():
            tr, pr = yield from aiozmq.create_zmq_connection(
                lambda: Protocol(self.loop),
                zmq.SUB,
                connect='tcp://127.0.0.1:{}'.format(port),
                loop=self.loop)
            yield from pr.connected
            return tr, pr

        tr, pr = self.loop.run_until_complete(connect())
        try:
            with self.assertRaises(OSError) as ctx:
                tr.getsockopt(1111)  # invalid option
            self.assertEqual(zmq.EINVAL, ctx.exception.errno)
        finally:
            # Close the transport so the socket is not leaked.
            tr.close()

    def test_setsockopt_badopt(self):
        """setsockopt() with an unknown option raises OSError(EINVAL)."""
        port = find_unused_port()

        @asyncio.coroutine
        def connect():
            tr, pr = yield from aiozmq.create_zmq_connection(
                lambda: Protocol(self.loop),
                zmq.SUB,
                connect='tcp://127.0.0.1:{}'.format(port),
                loop=self.loop)
            yield from pr.connected
            return tr, pr

        tr, pr = self.loop.run_until_complete(connect())
        try:
            with self.assertRaises(OSError) as ctx:
                tr.setsockopt(1111, 1)  # invalid option
            self.assertEqual(zmq.EINVAL, ctx.exception.errno)
        finally:
            # Close the transport so the socket is not leaked.
            tr.close()

    def test_unbind_from_nonbinded_addr(self):
        """unbind() of an endpoint never bound raises OSError.

        The expected errno is ENOENT or EPROTONOSUPPORT, and the existing
        bindings must be left unchanged.
        """
        port = find_unused_port()
        addr = 'tcp://127.0.0.1:{}'.format(port)

        @asyncio.coroutine
        def connect():
            tr, pr = yield from aiozmq.create_zmq_connection(
                lambda: Protocol(self.loop),
                zmq.SUB,
                bind=addr,
                loop=self.loop)
            yield from pr.connected
            try:
                self.assertEqual({addr}, tr.bindings())
                with self.assertRaises(OSError) as ctx:
                    yield from tr.unbind('ipc:///some-addr')  # non-bound addr

                # TODO: check travis build and remove skip when test passed.
                if (ctx.exception.errno == zmq.EAGAIN and
                        os.environ.get('TRAVIS')):
                    raise unittest.SkipTest("Travis has a bug, it returns "
                                            "EAGAIN for unknown endpoint")
                self.assertIn(ctx.exception.errno,
                              (errno.ENOENT, zmq.EPROTONOSUPPORT))
                self.assertEqual({addr}, tr.bindings())
            finally:
                # Close the transport on every path, including the skip path.
                tr.close()

        self.loop.run_until_complete(connect())

    def test_disconnect_from_nonbinded_addr(self):
        """disconnect() from an endpoint never connected raises OSError.

        The expected errno is ENOENT or EPROTONOSUPPORT, and the existing
        connections must be left unchanged.
        """
        port = find_unused_port()
        addr = 'tcp://127.0.0.1:{}'.format(port)

        @asyncio.coroutine
        def go():
            tr, pr = yield from aiozmq.create_zmq_connection(
                lambda: Protocol(self.loop),
                zmq.SUB,
                connect=addr,
                loop=self.loop)
            yield from pr.connected
            try:
                self.assertEqual({addr}, tr.connections())
                with self.assertRaises(OSError) as ctx:
                    yield from tr.disconnect('ipc:///some-addr')  # non-bound addr

                # TODO: check travis build and remove skip when test passed.
                if (ctx.exception.errno == zmq.EAGAIN and
                        os.environ.get('TRAVIS')):
                    raise unittest.SkipTest("Travis has a bug, it returns "
                                            "EAGAIN for unknown endpoint")
                self.assertIn(ctx.exception.errno,
                              (errno.ENOENT, zmq.EPROTONOSUPPORT))
                self.assertEqual({addr}, tr.connections())
            finally:
                # Close the transport on every path, including the skip path.
                tr.close()

        self.loop.run_until_complete(go())

    def test_subscriptions_of_invalid_socket(self):
        """subscribe/unsubscribe/subscriptions on a PUSH transport raise
        NotImplementedError."""

        @asyncio.coroutine
        def connect():
            tr, pr = yield from aiozmq.create_zmq_connection(
                lambda: Protocol(self.loop),
                zmq.PUSH,
                bind='tcp://127.0.0.1:*',
                loop=self.loop)
            yield from pr.connected
            return tr, pr

        tr, pr = self.loop.run_until_complete(connect())
        try:
            self.assertRaises(NotImplementedError, tr.subscribe, b'a')
            self.assertRaises(NotImplementedError, tr.unsubscribe, b'a')
            self.assertRaises(NotImplementedError, tr.subscriptions)
        finally:
            # Close the transport so the socket is not leaked.
            tr.close()

    def test_double_subscribe(self):
        """Subscribing to the same topic twice keeps a single entry."""

        @asyncio.coroutine
        def connect():
            tr, pr = yield from aiozmq.create_zmq_connection(
                lambda: Protocol(self.loop),
                zmq.SUB,
                bind='tcp://127.0.0.1:*',
                loop=self.loop)
            yield from pr.connected
            return tr, pr

        tr, pr = self.loop.run_until_complete(connect())
        try:
            tr.subscribe(b'val')
            self.assertEqual({b'val'}, tr.subscriptions())

            tr.subscribe(b'val')
            self.assertEqual({b'val'}, tr.subscriptions())
        finally:
            # Close the transport so the socket is not leaked.
            tr.close()

    def test_double_unsubscribe(self):
        """Unsubscribing twice from a topic leaves no subscriptions.

        An OSError with errno ENOTSOCK turns into a skip. Any other OSError
        is re-raised, so it fails the test instead of passing silently.
        """

        @asyncio.coroutine
        def connect():
            tr, pr = yield from aiozmq.create_zmq_connection(
                lambda: Protocol(self.loop),
                zmq.SUB,
                bind='tcp://127.0.0.1:*',
                loop=self.loop)
            yield from pr.connected
            return tr, pr

        tr = None
        try:
            tr, pr = self.loop.run_until_complete(connect())
            tr.subscribe(b'val')
            self.assertEqual({b'val'}, tr.subscriptions())

            tr.unsubscribe(b'val')
            self.assertFalse(tr.subscriptions())
            tr.unsubscribe(b'val')
            self.assertFalse(tr.subscriptions())
        except OSError as exc:
            if exc.errno == errno.ENOTSOCK:
                # I'm sad but ZMQ sometimes throws that error
                raise unittest.SkipTest("Malformed answer")
            raise
        finally:
            if tr is not None:
                tr.close()

    def test_unsubscribe_unknown_filter(self):
        """Unsubscribing from a topic never subscribed is a harmless no-op."""

        @asyncio.coroutine
        def connect():
            tr, pr = yield from aiozmq.create_zmq_connection(
                lambda: Protocol(self.loop),
                zmq.SUB,
                bind='tcp://127.0.0.1:*',
                loop=self.loop)
            yield from pr.connected
            return tr, pr

        tr, pr = self.loop.run_until_complete(connect())
        try:
            tr.unsubscribe(b'val')
            self.assertFalse(tr.subscriptions())
            tr.unsubscribe(b'val')
            self.assertFalse(tr.subscriptions())
        finally:
            # Close the transport so the socket is not leaked.
            tr.close()

    def test_endpoint_is_not_a_str(self):
        """bind/unbind/connect/disconnect reject a non-str endpoint with
        TypeError."""

        @asyncio.coroutine
        def go():
            tr, pr = yield from aiozmq.create_zmq_connection(
                lambda: Protocol(self.loop),
                zmq.PUSH,
                bind='tcp://127.0.0.1:*',
                loop=self.loop)
            yield from pr.connected
            try:
                with self.assertRaises(TypeError):
                    yield from tr.bind(123)

                with self.assertRaises(TypeError):
                    yield from tr.unbind(123)

                with self.assertRaises(TypeError):
                    yield from tr.connect(123)

                with self.assertRaises(TypeError):
                    yield from tr.disconnect(123)
            finally:
                # Close the transport so the socket is not leaked.
                tr.close()

        self.loop.run_until_complete(go())

    def test_transfer_big_data(self):
        """The ROUTER receives a 26-frame message (1000 bytes per frame)
        intact after the DEALER writes it 2000 times."""

        @asyncio.coroutine
        def go():
            (dealer_tr, _dealer_pr,
             router_tr, router_pr) = yield from self.make_dealer_router()

            # frames b'AAA...', b'BBB...', ..., b'ZZZ...' (1000 bytes each)
            frames = [bytes([code]) * 1000 for code in range(65, 65 + 26)]

            for _ in range(2000):
                dealer_tr.write(frames)

            first = yield from router_pr.received.get()
            self.assertEqual([mock.ANY] + frames, first)

            dealer_tr.close()
            router_tr.close()

        self.loop.run_until_complete(go())

    def test_transfer_big_data_send_after_closing(self):
        """Data buffered before close() is still delivered.

        10000 large writes fill the DEALER's buffer and pause its protocol.
        After close(), the ROUTER still receives every message.
        """

        @asyncio.coroutine
        def go():
            (dealer_tr, dealer_pr,
             router_tr, router_pr) = yield from self.make_dealer_router()

            # frames b'AAA...', b'BBB...', ..., b'ZZZ...' (1000 bytes each)
            frames = [bytes([code]) * 1000 for code in range(65, 65 + 26)]
            total = 10000

            self.assertFalse(dealer_pr.paused)

            for _ in range(total):
                dealer_tr.write(frames)

            self.assertTrue(dealer_tr._buffer)
            self.assertTrue(dealer_pr.paused)
            dealer_tr.close()

            for _ in range(total):
                msg = yield from router_pr.received.get()
                self.assertEqual([mock.ANY] + frames, msg)
            router_tr.close()

        self.loop.run_until_complete(go())

    def test_default_event_loop(self):
        """Without loop=, the transport uses the current event loop."""
        asyncio.set_event_loop(self.loop)
        endpoint = 'tcp://127.0.0.1:{}'.format(find_unused_port())
        pending = aiozmq.create_zmq_connection(
            lambda: Protocol(self.loop),
            zmq.REQ,
            bind=endpoint)
        transport, _ = self.loop.run_until_complete(pending)
        self.assertIs(self.loop, transport._loop)
        transport.close()

    def test_close_closing(self):
        """close() sets _closing, and calling it a second time is safe."""
        endpoint = 'tcp://127.0.0.1:{}'.format(find_unused_port())
        transport, _ = self.loop.run_until_complete(
            aiozmq.create_zmq_connection(
                lambda: Protocol(self.loop),
                zmq.REQ,
                bind=endpoint,
                loop=self.loop))
        for _ in range(2):
            transport.close()
            self.assertTrue(transport._closing)

    def test_pause_reading(self):
        # pause_reading() sets and resume_reading() clears the _paused flag.
        port = find_unused_port()
        connecting = aiozmq.create_zmq_connection(
            lambda: Protocol(self.loop),
            zmq.REQ,
            bind='tcp://127.0.0.1:{}'.format(port),
            loop=self.loop)
        transport, _ = self.loop.run_until_complete(connecting)

        self.assertFalse(transport._paused)
        transport.pause_reading()
        self.assertTrue(transport._paused)
        transport.resume_reading()
        self.assertFalse(transport._paused)

        transport.close()

    def test_pause_reading_closed(self):
        # Pausing a transport that has already been closed is a RuntimeError.
        address = 'tcp://127.0.0.1:{}'.format(find_unused_port())
        transport, _ = self.loop.run_until_complete(
            aiozmq.create_zmq_connection(
                lambda: Protocol(self.loop),
                zmq.REQ,
                bind=address,
                loop=self.loop))
        transport.close()
        self.assertRaises(RuntimeError, transport.pause_reading)

    def test_pause_reading_paused(self):
        # A second pause_reading() on an already paused transport raises.
        address = 'tcp://127.0.0.1:{}'.format(find_unused_port())
        transport, _ = self.loop.run_until_complete(
            aiozmq.create_zmq_connection(
                lambda: Protocol(self.loop),
                zmq.REQ,
                bind=address,
                loop=self.loop))

        transport.pause_reading()
        self.assertTrue(transport._paused)
        self.assertRaises(RuntimeError, transport.pause_reading)

        transport.close()

    def test_resume_reading_not_paused(self):
        # resume_reading() on a transport that was never paused raises.
        address = 'tcp://127.0.0.1:{}'.format(find_unused_port())
        transport, _ = self.loop.run_until_complete(
            aiozmq.create_zmq_connection(
                lambda: Protocol(self.loop),
                zmq.REQ,
                bind=address,
                loop=self.loop))
        self.assertRaises(RuntimeError, transport.resume_reading)
        transport.close()

    @mock.patch('aiozmq.core.logger')
    def test_warning_on_connection_lost(self, m_log):
        # close() sets _conn_lost to 1 and every later write bumps it. With
        # LOG_THRESHOLD_FOR_CONNLOST_WRITES at 2, the write that takes the
        # counter to 3 is the first one that logs a warning.
        address = 'tcp://127.0.0.1:{}'.format(find_unused_port())
        transport, _ = self.loop.run_until_complete(
            aiozmq.create_zmq_connection(
                lambda: Protocol(self.loop),
                zmq.REQ,
                bind=address,
                loop=self.loop))
        self.assertEqual(0, transport._conn_lost)
        transport.LOG_THRESHOLD_FOR_CONNLOST_WRITES = 2

        transport.close()
        self.assertEqual(1, transport._conn_lost)

        transport.write([b'data'])
        self.assertEqual(2, transport._conn_lost)
        self.assertFalse(m_log.warning.called)

        transport.write([b'data'])
        self.assertEqual(3, transport._conn_lost)
        m_log.warning.assert_called_with('write to closed ZMQ socket.')

    def test_close_on_error(self):
        """Writing to an externally closed zmq socket is a fatal error.

        The loop's exception handler must be called with the transport,
        protocol and the underlying ENOTSOCK error, and the transport must
        start closing.
        """
        port = find_unused_port()
        tr1, pr1 = self.loop.run_until_complete(aiozmq.create_zmq_connection(
            lambda: Protocol(self.loop),
            zmq.REQ,
            bind='tcp://127.0.0.1:{}'.format(port),
            loop=self.loop))
        handler = mock.Mock()
        self.loop.set_exception_handler(handler)
        # Close the raw socket behind the transport's back.
        sock = tr1.get_extra_info('zmq_socket')
        sock.close()
        tr1.write([b'data'])
        self.assertTrue(tr1._closing)
        handler.assert_called_with(
            self.loop,
            {'protocol': pr1,
             'exception': mock.ANY,
             'transport': tr1,
             'message': 'Fatal write error on zmq socket transport'})
        # Expecting 'Socket operation on non-socket' (ENOTSOCK).
        # zmq.ENOTSOCK has the platform's value (previously hard-coded as 38
        # on macOS and 88 elsewhere). It also avoids a local named ``errno``
        # that shadowed the stdlib module.
        check_errno(zmq.ENOTSOCK, handler.call_args[0][1]['exception'])

    def test_double_force_close(self):
        # Reporting the same fatal error twice must still let the protocol
        # reach its closed state.
        address = 'tcp://127.0.0.1:{}'.format(find_unused_port())
        transport, protocol = self.loop.run_until_complete(
            aiozmq.create_zmq_connection(
                lambda: Protocol(self.loop),
                zmq.REQ,
                bind=address,
                loop=self.loop))
        self.loop.set_exception_handler(mock.Mock())
        failure = RuntimeError('error')
        for _ in range(2):
            transport._fatal_error(failure)
        self.loop.run_until_complete(protocol.closed)

    def test___repr__(self):
        # repr() of a freshly connected DEALER transport shows the socket,
        # the socket type and idle read/write state with an empty buffer.
        port = find_unused_port()
        pattern = ('<ZmqTransport sock=<[^>]+> '
                   'type=DEALER read=idle write=<idle, bufsize=0>>')

        @asyncio.coroutine
        def coro():
            transport, protocol = yield from aiozmq.create_zmq_connection(
                lambda: Protocol(self.loop),
                zmq.DEALER,
                bind='tcp://127.0.0.1:{}'.format(port),
                loop=self.loop)
            yield from protocol.connected
            self.assertRegex(repr(transport), pattern)
            transport.close()

        self.loop.run_until_complete(coro())

    def test_extra_zmq_type(self):
        # get_extra_info('zmq_type') reports the socket type the transport
        # was created with.
        address = 'tcp://127.0.0.1:{}'.format(find_unused_port())

        @asyncio.coroutine
        def coro():
            transport, protocol = yield from aiozmq.create_zmq_connection(
                lambda: Protocol(self.loop),
                zmq.DEALER,
                bind=address,
                loop=self.loop)
            yield from protocol.connected

            reported = transport.get_extra_info('zmq_type')
            self.assertEqual(zmq.DEALER, reported)
            transport.close()

        self.loop.run_until_complete(coro())

    @unittest.skipIf(
        zmq.zmq_version_info() < (4,) or zmq.pyzmq_version_info() < (14, 4),
        "Socket monitor requires libzmq >= 4 and pyzmq >= 14.4")
    def test_implicit_monitor_disable(self):
        # After a graceful close() has completed, the transport no longer
        # holds the monitor that enable_monitor() created.

        @asyncio.coroutine
        def go():
            transport, protocol = yield from aiozmq.create_zmq_connection(
                lambda: Protocol(self.loop), zmq.DEALER, loop=self.loop)
            yield from protocol.connected
            yield from transport.enable_monitor()

            transport.close()
            yield from protocol.closed
            self.assertIsNone(transport._monitor)

        self.loop.run_until_complete(go())

    @unittest.skipIf(
        zmq.zmq_version_info() < (4,) or zmq.pyzmq_version_info() < (14, 4),
        "Socket monitor requires libzmq >= 4 and pyzmq >= 14.4")
    def test_force_close_monitor(self):
        # abort() also drops the monitor that enable_monitor() created.

        @asyncio.coroutine
        def go():
            transport, protocol = yield from aiozmq.create_zmq_connection(
                lambda: Protocol(self.loop), zmq.DEALER, loop=self.loop)
            yield from protocol.connected
            yield from transport.enable_monitor()

            transport.abort()
            yield from protocol.closed
            self.assertIsNone(transport._monitor)

        self.loop.run_until_complete(go())
Ejemplo n.º 16
0
Archivo: setup.py Proyecto: iquaba/salt
    def get_esky_freezer_includes(self):
        '''
        Return the list of module names esky should explicitly include.

        The list starts from a platform-independent base and is then extended
        with Windows-, Linux- or SunOS-specific modules.
        '''
        # Sometimes the auto module traversal doesn't find everything, so we
        # explicitly add it. The auto dependency tracking especially does not work for
        # imports occurring in salt.modules, as they are loaded at salt runtime.
        # Specifying includes that don't exist doesn't appear to cause a freezing
        # error.
        freezer_includes = [
            'zmq.core.*',
            'zmq.utils.*',
            'ast',
            'csv',
            'difflib',
            'distutils',
            'distutils.version',
            'numbers',
            'json',
            'M2Crypto',
            'Cookie',
            'asyncore',
            'fileinput',
            'sqlite3',
            'email',
            'email.mime.*',
            'requests',
        ]
        # We're freezing, and when freezing ZMQ needs to be installed, so this
        # works fine.
        # pyzmq release numbers go 2.x -> 13.x -> 14.x, so the old ``>= (0, 14)``
        # check was true for every release. The intent ("PyZMQ >= 0.14")
        # presumably meant pyzmq 14. Older releases keep 'zmq.core.*', which
        # is harmless even if the package is absent (see note above).
        if HAS_ZMQ and hasattr(zmq, 'pyzmq_version_info'):
            if zmq.pyzmq_version_info() >= (14,):
                # For PyZMQ >= 14, freezing does not need 'zmq.core.*'
                freezer_includes.remove('zmq.core.*')

        if IS_WINDOWS_PLATFORM:
            freezer_includes.extend([
                'imp',
                'win32api',
                'win32file',
                'win32con',
                'win32com',
                'win32net',
                'win32netcon',
                'win32gui',
                'win32security',
                'ntsecuritycon',
                'pywintypes',
                'pythoncom',
                '_winreg',
                'wmi',
                'site',
                'psutil',
            ])
        elif sys.platform.startswith('linux'):
            freezer_includes.append('spwd')
            try:
                import yum  # pylint: disable=unused-variable
                freezer_includes.append('yum')
            except ImportError:
                pass
        elif sys.platform.startswith('sunos'):
            # (The sledgehammer approach)
            # Just try to include everything
            # (This may be a better way to generate freezer_includes generally)
            try:
                from bbfreeze.modulegraph.modulegraph import ModuleGraph
                mgraph = ModuleGraph(sys.path[:])
                for arg in glob.glob('salt/modules/*.py'):
                    mgraph.run_script(arg)
                for mod in mgraph.flatten():
                    if type(mod).__name__ != 'Script' and mod.filename:
                        freezer_includes.append(str(os.path.basename(mod.identifier)))
            except ImportError:
                pass
            # Include C extension that convinces esky to package up the libsodium C library
            # This is needed for ctypes to find it in libnacl which is in turn needed for raet
            # see pkg/smartos/esky/sodium_grabber{.c,_installer.py}
            freezer_includes.extend([
                'sodium_grabber',
                'ioflo',
                'raet',
                'libnacl',
            ])
        return freezer_includes
Ejemplo n.º 17
0
class BaseZmqEventLoopTestsMixin:
    """ZMQ transport tests shared by several event-loop test cases.

    Meant to be mixed into a ``unittest.TestCase`` subclass. The tests use
    the ``assert*`` helpers and expect ``self.loop`` to be the event loop
    that drives every coroutine through ``run_until_complete``.
    """

    async def make_dealer_router(self):
        """Create a DEALER bound to a free local TCP port and a ROUTER
        connected to it.

        Returns ``(tr1, pr1, tr2, pr2)``: the DEALER's transport and protocol,
        then the ROUTER's. Both protocols have been awaited until connected.
        """
        port = find_unused_port()

        tr1, pr1 = await aiozmq.create_zmq_connection(
            lambda: Protocol(self.loop),
            zmq.DEALER,
            bind="tcp://127.0.0.1:{}".format(port),
        )
        self.assertEqual("CONNECTED", pr1.state)
        await pr1.connected

        tr2, pr2 = await aiozmq.create_zmq_connection(
            lambda: Protocol(self.loop),
            zmq.ROUTER,
            connect="tcp://127.0.0.1:{}".format(port),
        )
        self.assertEqual("CONNECTED", pr2.state)
        await pr2.connected

        return tr1, pr1, tr2, pr2

    async def make_pub_sub(self):
        """Like make_dealer_router(), but for a bound PUB and a connected SUB.

        Returns ``(tr1, pr1, tr2, pr2)`` for the PUB and SUB sides.
        """
        port = find_unused_port()

        tr1, pr1 = await aiozmq.create_zmq_connection(
            lambda: Protocol(self.loop),
            zmq.PUB,
            bind="tcp://127.0.0.1:{}".format(port),
        )
        self.assertEqual("CONNECTED", pr1.state)
        await pr1.connected

        tr2, pr2 = await aiozmq.create_zmq_connection(
            lambda: Protocol(self.loop),
            zmq.SUB,
            connect="tcp://127.0.0.1:{}".format(port),
        )
        self.assertEqual("CONNECTED", pr2.state)
        await pr2.connected

        return tr1, pr1, tr2, pr2

    def test_req_rep(self):
        """REQ/REP over inproc: one request, one answer, then a clean close."""

        async def connect_req():
            tr1, pr1 = await aiozmq.create_zmq_connection(
                lambda: Protocol(self.loop),
                zmq.REQ,
                bind="inproc://test",
            )
            self.assertEqual("CONNECTED", pr1.state)
            await pr1.connected
            return tr1, pr1

        tr1, pr1 = self.loop.run_until_complete(connect_req())

        async def connect_rep():
            tr2, pr2 = await aiozmq.create_zmq_connection(
                lambda: Protocol(self.loop),
                zmq.REP,
                connect="inproc://test",
            )
            self.assertEqual("CONNECTED", pr2.state)
            await pr2.connected
            return tr2, pr2

        tr2, pr2 = self.loop.run_until_complete(connect_rep())
        # Without this, this test hangs for some reason.
        tr2._zmq_sock.getsockopt(zmq.EVENTS)

        async def communicate():
            tr1.write([b"request"])
            request = await pr2.received.get()
            self.assertEqual([b"request"], request)
            tr2.write([b"answer"])
            answer = await pr1.received.get()
            self.assertEqual([b"answer"], answer)

        self.loop.run_until_complete(communicate())

        async def closing():
            tr1.close()
            tr2.close()

            await pr1.closed
            self.assertEqual("CLOSED", pr1.state)
            await pr2.closed
            self.assertEqual("CLOSED", pr2.state)

        self.loop.run_until_complete(closing())

    def test_pub_sub(self):
        """A subscribed SUB eventually receives a published message."""

        async def go():
            tr1, pr1, tr2, pr2 = await self.make_pub_sub()
            tr2.setsockopt(zmq.SUBSCRIBE, b"node_id")

            # The subscription is applied asynchronously, so early publishes
            # may be dropped; retry a few times before failing.
            for i in range(5):
                tr1.write([b"node_id", b"publish"])
                try:
                    request = await asyncio.wait_for(pr2.received.get(), 0.1)
                    self.assertEqual([b"node_id", b"publish"], request)
                    break
                except asyncio.TimeoutError:
                    pass
            else:
                raise AssertionError("Cannot get message in subscriber")

            tr1.close()
            tr2.close()
            await pr1.closed
            self.assertEqual("CLOSED", pr1.state)
            await pr2.closed
            self.assertEqual("CLOSED", pr2.state)

        self.loop.run_until_complete(go())

    def test_getsockopt(self):
        """getsockopt(zmq.TYPE) reports the transport's socket type."""
        port = find_unused_port()

        async def coro():
            tr, pr = await aiozmq.create_zmq_connection(
                lambda: Protocol(self.loop),
                zmq.DEALER,
                bind="tcp://127.0.0.1:{}".format(port),
            )
            await pr.connected
            self.assertEqual(zmq.DEALER, tr.getsockopt(zmq.TYPE))
            # Close the transport so the bound port is not leaked.
            tr.close()
            return tr, pr

        self.loop.run_until_complete(coro())

    def test_dealer_router(self):
        """DEALER -> ROUTER request and routed answer, then a clean close."""

        async def go():
            tr1, pr1, tr2, pr2 = await self.make_dealer_router()
            tr1.write([b"request"])
            request = await pr2.received.get()
            self.assertEqual([mock.ANY, b"request"], request)
            tr2.write([request[0], b"answer"])
            answer = await pr1.received.get()
            self.assertEqual([b"answer"], answer)

            tr1.close()
            tr2.close()

            await pr1.closed
            self.assertEqual("CLOSED", pr1.state)
            await pr2.closed
            self.assertEqual("CLOSED", pr2.state)

        self.loop.run_until_complete(go())

    def test_binds(self):
        """bindings() tracks initial binds, bind() and unbind()."""
        port1 = find_unused_port()
        port2 = find_unused_port()
        addr1 = "tcp://127.0.0.1:{}".format(port1)
        addr2 = "tcp://127.0.0.1:{}".format(port2)

        async def connect():
            tr, pr = await aiozmq.create_zmq_connection(
                lambda: Protocol(self.loop),
                zmq.REQ,
                bind=[addr1, addr2],
            )
            await pr.connected

            self.assertEqual({addr1, addr2}, tr.bindings())

            addr3 = await tr.bind("tcp://127.0.0.1:*")
            self.assertEqual({addr1, addr2, addr3}, tr.bindings())
            await tr.unbind(addr2)
            self.assertEqual({addr1, addr3}, tr.bindings())
            self.assertIn(addr1, tr.bindings())
            # Dots are escaped and the stray '.' before the first port is
            # gone, so the pattern matches exactly two endpoints.
            self.assertRegex(
                repr(tr.bindings()),
                r"{tcp://127\.0\.0\.1:\d+, tcp://127\.0\.0\.1:\d+}",
            )
            tr.close()

        self.loop.run_until_complete(connect())

    def test_connects(self):
        """connections() tracks initial connects, connect() and disconnect()."""
        port1 = find_unused_port()
        port2 = find_unused_port()
        port3 = find_unused_port()
        addr1 = "tcp://127.0.0.1:{}".format(port1)
        addr2 = "tcp://127.0.0.1:{}".format(port2)
        addr3 = "tcp://127.0.0.1:{}".format(port3)

        async def go():
            tr, pr = await aiozmq.create_zmq_connection(
                lambda: Protocol(self.loop),
                zmq.REQ,
                connect=[addr1, addr2],
            )
            await pr.connected

            self.assertEqual({addr1, addr2}, tr.connections())
            await tr.connect(addr3)
            self.assertEqual({addr1, addr3, addr2}, tr.connections())
            await tr.disconnect(addr1)
            self.assertEqual({addr2, addr3}, tr.connections())
            tr.close()

        self.loop.run_until_complete(go())

    def test_zmq_socket(self):
        """A caller-supplied zmq socket is used as-is and left open."""
        zmq_sock = zmq.Context.instance().socket(zmq.PUB)

        async def connect():
            tr, pr = await aiozmq.create_zmq_connection(
                lambda: Protocol(self.loop), zmq.PUB, zmq_sock=zmq_sock
            )
            await pr.connected
            return tr, pr

        tr, pr = self.loop.run_until_complete(connect())
        self.assertIs(zmq_sock, tr._zmq_sock)
        self.assertFalse(zmq_sock.closed)
        tr.close()

    def test_zmq_socket_invalid_type(self):
        """A socket type mismatch raises ValueError without closing the socket."""
        zmq_sock = zmq.Context.instance().socket(zmq.PUB)

        async def connect():
            tr, pr = await aiozmq.create_zmq_connection(
                lambda: Protocol(self.loop), zmq.SUB, zmq_sock=zmq_sock
            )
            await pr.connected
            return tr, pr

        with self.assertRaises(ValueError):
            self.loop.run_until_complete(connect())
        self.assertFalse(zmq_sock.closed)
        # The test created this socket, so it is responsible for closing it.
        zmq_sock.close()

    def test_create_zmq_connection_ZMQError(self):
        """Passing an already closed zmq socket raises OSError."""
        zmq_sock = zmq.Context.instance().socket(zmq.PUB)
        zmq_sock.close()

        async def connect():
            tr, pr = await aiozmq.create_zmq_connection(
                lambda: Protocol(self.loop), zmq.SUB, zmq_sock=zmq_sock
            )
            await pr.connected
            return tr, pr

        with self.assertRaises(OSError) as ctx:
            self.loop.run_until_complete(connect())
        self.assertIn(ctx.exception.errno, (zmq.ENOTSUP, zmq.ENOTSOCK))

    def test_create_zmq_connection_invalid_bind(self):
        """A non-string, non-iterable ``bind`` value raises ValueError."""

        async def connect():
            tr, pr = await aiozmq.create_zmq_connection(
                lambda: Protocol(self.loop), zmq.SUB, bind=2
            )

        with self.assertRaises(ValueError):
            self.loop.run_until_complete(connect())

    def test_create_zmq_connection_invalid_connect(self):
        """A non-string, non-iterable ``connect`` value raises ValueError."""

        async def connect():
            tr, pr = await aiozmq.create_zmq_connection(
                lambda: Protocol(self.loop), zmq.SUB, connect=2
            )

        with self.assertRaises(ValueError):
            self.loop.run_until_complete(connect())

    @unittest.skipIf(sys.platform == "win32", "Windows calls abort() on bad socket")
    def test_create_zmq_connection_closes_socket_on_bad_bind(self):
        """An unparsable bind address raises OSError."""

        async def connect():
            tr, pr = await aiozmq.create_zmq_connection(
                lambda: Protocol(self.loop), zmq.SUB, bind="badaddr"
            )
            await pr.connected
            return tr, pr

        with self.assertRaises(OSError):
            self.loop.run_until_complete(connect())

    @unittest.skipIf(sys.platform == "win32", "Windows calls abort() on bad socket")
    def test_create_zmq_connection_closes_socket_on_bad_connect(self):
        """An unparsable connect address raises OSError."""

        async def connect():
            with self.assertRaises(OSError):
                await aiozmq.create_zmq_connection(
                    lambda: Protocol(self.loop),
                    zmq.SUB,
                    connect="badaddr",
                )

        self.loop.run_until_complete(connect())

    def test_create_zmq_connection_dns_in_connect(self):
        """A host name (not just an IP literal) is accepted in ``connect``."""
        # NOTE(review): in the source this test was garbled. The address
        # literal ran straight into the next test's decorator and the rest of
        # the body was lost. It is reconstructed here from its name and the
        # sibling tests. Confirm against upstream aiozmq, which may also have
        # had further tests at this point.
        port = find_unused_port()

        async def connect():
            addr = "tcp://localhost:{}".format(port)
            tr, pr = await aiozmq.create_zmq_connection(
                lambda: Protocol(self.loop), zmq.SUB, connect=addr
            )
            await pr.connected

            self.assertEqual({addr}, tr.connections())
            tr.close()

        self.loop.run_until_complete(connect())

    @mock.patch("aiozmq.core.logger")
    def test_warning_on_connection_lost(self, m_log):
        """Writes after close() warn once _conn_lost passes the threshold.

        With LOG_THRESHOLD_FOR_CONNLOST_WRITES at 2, the write that takes
        ``_conn_lost`` to 3 is the first one that logs a warning.
        """
        port = find_unused_port()
        tr1, pr1 = self.loop.run_until_complete(
            aiozmq.create_zmq_connection(
                lambda: Protocol(self.loop),
                zmq.REQ,
                bind="tcp://127.0.0.1:{}".format(port),
            )
        )
        self.assertEqual(0, tr1._conn_lost)
        tr1.LOG_THRESHOLD_FOR_CONNLOST_WRITES = 2
        tr1.close()
        self.assertEqual(1, tr1._conn_lost)
        tr1.write([b"data"])
        self.assertEqual(2, tr1._conn_lost)
        self.assertFalse(m_log.warning.called)
        tr1.write([b"data"])
        self.assertEqual(3, tr1._conn_lost)
        m_log.warning.assert_called_with("write to closed ZMQ socket.")

    def test_close_on_error(self):
        """Writing to an externally closed zmq socket is a fatal error."""
        port = find_unused_port()
        tr1, pr1 = self.loop.run_until_complete(
            aiozmq.create_zmq_connection(
                lambda: Protocol(self.loop),
                zmq.REQ,
                bind="tcp://127.0.0.1:{}".format(port),
            )
        )
        handler = mock.Mock()
        self.loop.set_exception_handler(handler)
        # Close the raw socket behind the transport's back.
        sock = tr1.get_extra_info("zmq_socket")
        sock.close()
        tr1.write([b"data"])
        self.assertTrue(tr1._closing)
        handler.assert_called_with(
            self.loop,
            {
                "protocol": pr1,
                "exception": mock.ANY,
                "transport": tr1,
                "message": "Fatal write error on zmq socket transport",
            },
        )
        # Expecting 'Socket operation on non-socket' (ENOTSOCK).
        # zmq.ENOTSOCK has the platform's value instead of the previously
        # hard-coded 38 (macOS) / 88 (elsewhere).
        check_errno(zmq.ENOTSOCK, handler.call_args[0][1]["exception"])

    def test_double_force_close(self):
        """Two back-to-back fatal errors still let the protocol close."""
        port = find_unused_port()
        tr1, pr1 = self.loop.run_until_complete(
            aiozmq.create_zmq_connection(
                lambda: Protocol(self.loop),
                zmq.REQ,
                bind="tcp://127.0.0.1:{}".format(port),
            )
        )
        handler = mock.Mock()
        self.loop.set_exception_handler(handler)
        err = RuntimeError("error")
        tr1._fatal_error(err)
        tr1._fatal_error(err)
        self.loop.run_until_complete(pr1.closed)

    def test___repr__(self):
        """repr() shows socket type and idle read/write state."""
        port = find_unused_port()

        async def coro():
            tr, pr = await aiozmq.create_zmq_connection(
                lambda: Protocol(self.loop),
                zmq.DEALER,
                bind="tcp://127.0.0.1:{}".format(port),
            )
            await pr.connected
            self.assertRegex(
                repr(tr),
                "<ZmqTransport sock=<[^>]+> "
                "type=DEALER read=idle write=<idle, bufsize=0>>",
            )
            tr.close()

        self.loop.run_until_complete(coro())

    def test_extra_zmq_type(self):
        """get_extra_info('zmq_type') reports the socket type."""
        port = find_unused_port()

        async def coro():
            tr, pr = await aiozmq.create_zmq_connection(
                lambda: Protocol(self.loop),
                zmq.DEALER,
                bind="tcp://127.0.0.1:{}".format(port),
            )
            await pr.connected

            self.assertEqual(zmq.DEALER, tr.get_extra_info("zmq_type"))
            tr.close()

        self.loop.run_until_complete(coro())

    @unittest.skipIf(
        zmq.zmq_version_info() < (4,) or zmq.pyzmq_version_info() < (14, 4),
        "Socket monitor requires libzmq >= 4 and pyzmq >= 14.4",
    )
    def test_implicit_monitor_disable(self):
        """After close() completes, the transport holds no monitor."""

        async def go():

            tr, pr = await aiozmq.create_zmq_connection(
                lambda: Protocol(self.loop), zmq.DEALER
            )
            await pr.connected

            await tr.enable_monitor()

            tr.close()
            await pr.closed

            self.assertIsNone(tr._monitor)

        self.loop.run_until_complete(go())

    @unittest.skipIf(
        zmq.zmq_version_info() < (4,) or zmq.pyzmq_version_info() < (14, 4),
        "Socket monitor requires libzmq >= 4 and pyzmq >= 14.4",
    )
    def test_force_close_monitor(self):
        """After abort() completes, the transport holds no monitor."""

        async def go():

            tr, pr = await aiozmq.create_zmq_connection(
                lambda: Protocol(self.loop), zmq.DEALER
            )
            await pr.connected

            await tr.enable_monitor()

            tr.abort()
            await pr.closed

            self.assertIsNone(tr._monitor)

        self.loop.run_until_complete(go())
Ejemplo n.º 18
0
    def get_esky_freezer_includes(self):
        '''
        Return the list of module names esky should explicitly include.

        The list starts from a platform-independent base and is then extended
        with Windows-, Linux- or SunOS-specific modules.
        '''
        # Sometimes the auto module traversal doesn't find everything, so we
        # explicitly add it. The auto dependency tracking especially does not work for
        # imports occurring in salt.modules, as they are loaded at salt runtime.
        # Specifying includes that don't exist doesn't appear to cause a freezing
        # error.
        freezer_includes = [
            'zmq.core.*',
            'zmq.utils.*',
            'ast',
            'difflib',
            'distutils',
            'distutils.version',
            'numbers',
            'json',
            'M2Crypto',
            'Cookie',
            'asyncore',
            'fileinput',
            'sqlite3',
            'email',
            'email.mime.*',
            'requests',
        ]
        # We're freezing, and when freezing ZMQ needs to be installed, so this
        # works fine.
        # pyzmq release numbers go 2.x -> 13.x -> 14.x, so the old ``>= (0, 14)``
        # check was true for every release. The intent ("PyZMQ >= 0.14")
        # presumably meant pyzmq 14. Older releases keep 'zmq.core.*', which
        # is harmless even if the package is absent (see note above).
        if HAS_ZMQ and hasattr(zmq, 'pyzmq_version_info'):
            if zmq.pyzmq_version_info() >= (14,):
                # For PyZMQ >= 14, freezing does not need 'zmq.core.*'
                freezer_includes.remove('zmq.core.*')

        if IS_WINDOWS_PLATFORM:
            freezer_includes.extend([
                'win32api',
                'win32file',
                'win32con',
                'win32com',
                'win32net',
                'win32netcon',
                'win32gui',
                'win32security',
                'ntsecuritycon',
                'pywintypes',
                'pythoncom',
                '_winreg',
                'wmi',
                'site',
                'psutil',
            ])
        elif sys.platform.startswith('linux'):
            freezer_includes.append('spwd')
            try:
                import yum  # pylint: disable=unused-variable
                freezer_includes.append('yum')
            except ImportError:
                pass
        elif sys.platform.startswith('sunos'):
            # (The sledgehammer approach)
            # Just try to include everything
            # (This may be a better way to generate freezer_includes generally)
            try:
                from bbfreeze.modulegraph.modulegraph import ModuleGraph
                mgraph = ModuleGraph(sys.path[:])
                for arg in glob.glob('salt/modules/*.py'):
                    mgraph.run_script(arg)
                for mod in mgraph.flatten():
                    if type(mod).__name__ != 'Script' and mod.filename:
                        freezer_includes.append(
                            str(os.path.basename(mod.identifier)))
            except ImportError:
                pass
            # Include C extension that convinces esky to package up the libsodium C library
            # This is needed for ctypes to find it in libnacl which is in turn needed for raet
            # see pkg/smartos/esky/sodium_grabber{.c,_installer.py}
            freezer_includes.extend([
                'sodium_grabber',
                'ioflo',
                'raet',
                'libnacl',
            ])
        return freezer_includes
Ejemplo n.º 19
0
    def get_esky_freezer_includes(self):
        """
        Return the list of module names esky should explicitly include.

        The list starts from a platform-independent base and is then extended
        with Windows-, SmartOS-, Linux- or SunOS-specific modules.
        """
        # Sometimes the auto module traversal doesn't find everything, so we
        # explicitly add it. The auto dependency tracking especially does not work for
        # imports occurring in salt.modules, as they are loaded at salt runtime.
        # Specifying includes that don't exist doesn't appear to cause a freezing
        # error.
        freezer_includes = [
            "zmq.core.*",
            "zmq.utils.*",
            "ast",
            "csv",
            "difflib",
            "distutils",
            "distutils.version",
            "numbers",
            "json",
            "M2Crypto",
            "Cookie",
            "asyncore",
            "fileinput",
            "sqlite3",
            "email",
            "email.mime.*",
            "requests",
        ]
        # We're freezing, and when freezing ZMQ needs to be installed, so this
        # works fine.
        # pyzmq release numbers go 2.x -> 13.x -> 14.x, so the old ``>= (0, 14)``
        # check was true for every release. The intent ("PyZMQ >= 0.14")
        # presumably meant pyzmq 14. Older releases keep 'zmq.core.*', which
        # is harmless even if the package is absent (see note above).
        if HAS_ZMQ and hasattr(zmq, "pyzmq_version_info"):
            if zmq.pyzmq_version_info() >= (14,):
                # For PyZMQ >= 14, freezing does not need 'zmq.core.*'
                freezer_includes.remove("zmq.core.*")

        if IS_WINDOWS_PLATFORM:
            freezer_includes.extend([
                "imp",
                "win32api",
                "win32file",
                "win32con",
                "win32com",
                "win32net",
                "win32netcon",
                "win32gui",
                "win32security",
                "ntsecuritycon",
                "pywintypes",
                "pythoncom",
                "_winreg",
                "wmi",
                "site",
                "psutil",
                "pytz",
            ])
        elif IS_SMARTOS_PLATFORM:
            # we have them as requirements in pkg/smartos/esky/requirements.txt
            # all these should be safe to force include.
            # Includes are importable module names, so python-dateutil is
            # listed under its module name "dateutil".
            freezer_includes.extend([
                "cherrypy", "dateutil", "pyghmi", "croniter", "mako",
                "gnupg"
            ])
        elif sys.platform.startswith("linux"):
            freezer_includes.append("spwd")
            try:
                import yum  # pylint: disable=unused-variable

                freezer_includes.append("yum")
            except ImportError:
                pass
        elif sys.platform.startswith("sunos"):
            # (The sledgehammer approach)
            # Just try to include everything
            # (This may be a better way to generate freezer_includes generally)
            try:
                from bbfreeze.modulegraph.modulegraph import ModuleGraph

                mgraph = ModuleGraph(sys.path[:])
                for arg in glob.glob("salt/modules/*.py"):
                    mgraph.run_script(arg)
                for mod in mgraph.flatten():
                    if type(mod).__name__ != "Script" and mod.filename:
                        freezer_includes.append(
                            str(os.path.basename(mod.identifier)))
            except ImportError:
                pass

        return freezer_includes
Ejemplo n.º 20
0
class ZmqStreamTests(unittest.TestCase):
    """Tests for the stream API returned by ``aiozmq.create_zmq_stream``.

    Each test drives a ``go()`` coroutine on a private
    ``aiozmq.ZmqEventLoop``. No loop is installed as the current loop,
    except in ``test_default_loop``, which exercises the implicit-loop path.
    """

    def setUp(self):
        self.loop = aiozmq.ZmqEventLoop()
        asyncio.set_event_loop(None)

    def tearDown(self):
        self.loop.close()
        # test_default_loop installs self.loop as the current loop; reset it
        # so a closed loop is never left installed once the test is over.
        asyncio.set_event_loop(None)

    def test_req_rep(self):
        """A DEALER request reaches the ROUTER and the reply comes back."""
        port = find_unused_port()

        @asyncio.coroutine
        def go():
            s1 = yield from aiozmq.create_zmq_stream(
                zmq.DEALER,
                bind='tcp://127.0.0.1:{}'.format(port),
                loop=self.loop)

            s2 = yield from aiozmq.create_zmq_stream(
                zmq.ROUTER,
                connect='tcp://127.0.0.1:{}'.format(port),
                loop=self.loop)

            s1.write([b'request'])
            req = yield from s2.read()
            # The first frame is the peer's routing id; it is echoed back
            # below to address the reply.
            self.assertEqual([mock.ANY, b'request'], req)
            s2.write([req[0], b'answer'])
            answer = yield from s1.read()
            self.assertEqual([b'answer'], answer)

            # Release both sockets instead of leaving them open.
            s2.close()
            s1.close()

        self.loop.run_until_complete(go())

    def test_closed(self):
        """Reading a closed stream raises ZmqStreamClosed."""
        port = find_unused_port()

        @asyncio.coroutine
        def go():
            s1 = yield from aiozmq.create_zmq_stream(
                zmq.DEALER,
                bind='tcp://127.0.0.1:{}'.format(port),
                loop=self.loop)

            s2 = yield from aiozmq.create_zmq_stream(
                zmq.ROUTER,
                connect='tcp://127.0.0.1:{}'.format(port),
                loop=self.loop)

            self.assertFalse(s2.at_closing())
            s2.close()
            s1.write([b'request'])
            with self.assertRaises(aiozmq.ZmqStreamClosed):
                yield from s2.read()
            self.assertTrue(s2.at_closing())

            # s2 is already closed; release the still-open s1 as well.
            s1.close()

        self.loop.run_until_complete(go())

    def test_transport(self):
        """The transport is exposed while open and detached after close."""
        @asyncio.coroutine
        def go():
            s1 = yield from aiozmq.create_zmq_stream(zmq.DEALER,
                                                     bind='tcp://127.0.0.1:*',
                                                     loop=self.loop)

            self.assertIsInstance(s1.transport, aiozmq.ZmqTransport)
            s1.close()
            with self.assertRaises(aiozmq.ZmqStreamClosed):
                yield from s1.read()
            self.assertIsNone(s1.transport)

        self.loop.run_until_complete(go())

    def test_get_extra_info(self):
        """get_extra_info('zmq_socket') returns the underlying zmq.Socket."""
        @asyncio.coroutine
        def go():
            s1 = yield from aiozmq.create_zmq_stream(zmq.DEALER,
                                                     bind='tcp://127.0.0.1:*',
                                                     loop=self.loop)

            self.assertIsInstance(s1.get_extra_info('zmq_socket'), zmq.Socket)
            s1.close()

        self.loop.run_until_complete(go())

    def test_exception(self):
        """A fresh stream reports no exception."""
        @asyncio.coroutine
        def go():
            s1 = yield from aiozmq.create_zmq_stream(zmq.DEALER,
                                                     bind='tcp://127.0.0.1:*',
                                                     loop=self.loop)

            self.assertIsNone(s1.exception())
            s1.close()

        self.loop.run_until_complete(go())

    def test_default_loop(self):
        """create_zmq_stream works without an explicit ``loop`` argument."""
        asyncio.set_event_loop(self.loop)

        @asyncio.coroutine
        def go():
            s1 = yield from aiozmq.create_zmq_stream(zmq.DEALER,
                                                     bind='tcp://127.0.0.1:*')

            s1.close()

        self.loop.run_until_complete(go())

    def test_set_read_buffer_limits1(self):
        """Given only ``low``, the high-water mark becomes 4 * low."""
        @asyncio.coroutine
        def go():
            s1 = yield from aiozmq.create_zmq_stream(zmq.DEALER,
                                                     bind='tcp://127.0.0.1:*',
                                                     loop=self.loop)

            s1.set_read_buffer_limits(low=10)
            self.assertEqual(10, s1._low_water)
            self.assertEqual(40, s1._high_water)

            s1.close()

        self.loop.run_until_complete(go())

    def test_set_read_buffer_limits2(self):
        """Given only ``high``, the low-water mark becomes a quarter of it."""
        @asyncio.coroutine
        def go():
            s1 = yield from aiozmq.create_zmq_stream(zmq.DEALER,
                                                     bind='tcp://127.0.0.1:*',
                                                     loop=self.loop)

            s1.set_read_buffer_limits(high=60)
            self.assertEqual(15, s1._low_water)
            self.assertEqual(60, s1._high_water)

            s1.close()

        self.loop.run_until_complete(go())

    def test_set_read_buffer_limits3(self):
        """A high-water mark below the low-water mark is rejected."""
        @asyncio.coroutine
        def go():
            s1 = yield from aiozmq.create_zmq_stream(zmq.DEALER,
                                                     bind='tcp://127.0.0.1:*',
                                                     loop=self.loop)

            with self.assertRaises(ValueError):
                s1.set_read_buffer_limits(high=1, low=2)

            s1.close()

        self.loop.run_until_complete(go())

    def test_pause_reading(self):
        """Reading pauses above the high-water mark and resumes on read()."""
        port = find_unused_port()

        @asyncio.coroutine
        def go():
            s1 = yield from aiozmq.create_zmq_stream(
                zmq.DEALER,
                bind='tcp://127.0.0.1:{}'.format(port),
                loop=self.loop)

            s2 = yield from aiozmq.create_zmq_stream(
                zmq.ROUTER,
                connect='tcp://127.0.0.1:{}'.format(port),
                loop=self.loop)

            # The 7-byte payload exceeds high=5, so s2 should pause reading.
            s2.set_read_buffer_limits(high=5)
            s1.write([b'request'])

            yield from asyncio.sleep(0.01, loop=self.loop)
            self.assertTrue(s2._paused)

            msg = yield from s2.read()
            self.assertEqual([mock.ANY, b'request'], msg)
            self.assertFalse(s2._paused)

            s2.close()
            s1.close()

        self.loop.run_until_complete(go())

    def test_set_exception(self):
        """set_exception() is reported by exception() and raised by read()."""
        @asyncio.coroutine
        def go():
            s1 = yield from aiozmq.create_zmq_stream(zmq.DEALER,
                                                     bind='tcp://127.0.0.1:*',
                                                     loop=self.loop)

            exc = RuntimeError('some exc')
            s1.set_exception(exc)
            self.assertIs(exc, s1.exception())

            with self.assertRaisesRegex(RuntimeError, 'some exc'):
                yield from s1.read()

        self.loop.run_until_complete(go())

    def test_set_exception_with_waiter(self):
        """set_exception() works while another task is blocked in read()."""
        @asyncio.coroutine
        def go():
            s1 = yield from aiozmq.create_zmq_stream(zmq.DEALER,
                                                     bind='tcp://127.0.0.1:*',
                                                     loop=self.loop)

            @asyncio.coroutine
            def f():
                yield from s1.read()

            t1 = ensure_future(f(), loop=self.loop)
            # to run f() up to yield from
            yield from asyncio.sleep(0.001, loop=self.loop)

            self.assertIsNotNone(s1._waiter)

            exc = RuntimeError('some exc')
            s1.set_exception(exc)
            self.assertIs(exc, s1.exception())

            with self.assertRaisesRegex(RuntimeError, 'some exc'):
                yield from s1.read()

            t1.cancel()

        self.loop.run_until_complete(go())

    def test_set_exception_with_cancelled_waiter(self):
        """set_exception() tolerates a pending read() that was cancelled."""
        @asyncio.coroutine
        def go():
            s1 = yield from aiozmq.create_zmq_stream(zmq.DEALER,
                                                     bind='tcp://127.0.0.1:*',
                                                     loop=self.loop)

            @asyncio.coroutine
            def f():
                yield from s1.read()

            t1 = ensure_future(f(), loop=self.loop)
            # to run f() up to yield from
            yield from asyncio.sleep(0.001, loop=self.loop)

            self.assertIsNotNone(s1._waiter)
            t1.cancel()

            exc = RuntimeError('some exc')
            s1.set_exception(exc)
            self.assertIs(exc, s1.exception())

            with self.assertRaisesRegex(RuntimeError, 'some exc'):
                yield from s1.read()

        self.loop.run_until_complete(go())

    def test_double_reading(self):
        """A second concurrent read() raises RuntimeError."""
        @asyncio.coroutine
        def go():
            s1 = yield from aiozmq.create_zmq_stream(zmq.DEALER,
                                                     bind='tcp://127.0.0.1:*',
                                                     loop=self.loop)

            @asyncio.coroutine
            def f():
                yield from s1.read()

            t1 = ensure_future(f(), loop=self.loop)
            # to run f() up to yield from
            yield from asyncio.sleep(0.001, loop=self.loop)

            with self.assertRaises(RuntimeError):
                yield from s1.read()

            t1.cancel()

        self.loop.run_until_complete(go())

    def test_close_on_reading(self):
        """Closing the stream fails a pending read() with ZmqStreamClosed."""
        port = find_unused_port()

        @asyncio.coroutine
        def go():
            s1 = yield from aiozmq.create_zmq_stream(
                zmq.DEALER,
                bind='tcp://127.0.0.1:{}'.format(port),
                loop=self.loop)

            @asyncio.coroutine
            def f():
                yield from s1.read()

            t1 = ensure_future(f(), loop=self.loop)
            # to run f() up to yield from
            yield from asyncio.sleep(0.001, loop=self.loop)

            s1.close()
            yield from asyncio.sleep(0.001, loop=self.loop)

            with self.assertRaises(aiozmq.ZmqStreamClosed):
                t1.result()

        self.loop.run_until_complete(go())

    def test_close_on_cancelled_reading(self):
        """feed_closing() after the reader was cancelled keeps it cancelled."""
        port = find_unused_port()

        @asyncio.coroutine
        def go():
            s1 = yield from aiozmq.create_zmq_stream(
                zmq.DEALER,
                bind='tcp://127.0.0.1:{}'.format(port),
                loop=self.loop)

            @asyncio.coroutine
            def f():
                yield from s1.read()

            t1 = ensure_future(f(), loop=self.loop)
            # to run f() up to yield from
            yield from asyncio.sleep(0.001, loop=self.loop)

            t1.cancel()
            s1.feed_closing()

            yield from asyncio.sleep(0.001, loop=self.loop)
            with self.assertRaises(asyncio.CancelledError):
                t1.result()

        self.loop.run_until_complete(go())

    def test_feed_cancelled_msg(self):
        """A message fed after the reader was cancelled stays queued."""
        port = find_unused_port()

        @asyncio.coroutine
        def go():
            s1 = yield from aiozmq.create_zmq_stream(
                zmq.DEALER,
                bind='tcp://127.0.0.1:{}'.format(port),
                loop=self.loop)

            @asyncio.coroutine
            def f():
                yield from s1.read()

            t1 = ensure_future(f(), loop=self.loop)
            # to run f() up to yield from
            yield from asyncio.sleep(0.001, loop=self.loop)

            t1.cancel()
            s1.feed_msg([b'data'])

            yield from asyncio.sleep(0.001, loop=self.loop)
            with self.assertRaises(asyncio.CancelledError):
                t1.result()

            # The queue records the payload size (4 bytes) alongside the
            # message itself.
            self.assertEqual(4, s1._queue_len)
            self.assertEqual((4, [b'data']), s1._queue.popleft())

        self.loop.run_until_complete(go())

    def test_error_on_read(self):
        """A write error surfaces from both read() and drain()."""
        port = find_unused_port()

        @asyncio.coroutine
        def go():
            s1 = yield from aiozmq.create_zmq_stream(
                zmq.REP,
                bind='tcp://127.0.0.1:{}'.format(port),
                loop=self.loop)
            handler = mock.Mock()
            self.loop.set_exception_handler(handler)
            # Writing first on a fresh REP socket is a state error (EFSM).
            s1.write([b'data'])
            with self.assertRaises(OSError) as ctx:
                yield from s1.read()
            check_errno(zmq.EFSM, ctx.exception)
            with self.assertRaises(OSError) as ctx2:
                yield from s1.drain()
            check_errno(zmq.EFSM, ctx2.exception)

        self.loop.run_until_complete(go())

    def test_drain(self):
        """drain() on an idle, unpaused stream returns immediately."""
        port = find_unused_port()

        @asyncio.coroutine
        def go():
            s1 = yield from aiozmq.create_zmq_stream(
                zmq.REP,
                bind='tcp://127.0.0.1:{}'.format(port),
                loop=self.loop)
            yield from s1.drain()
            s1.close()

        self.loop.run_until_complete(go())

    def test_pause_resume_connection(self):
        """pause_writing()/resume_writing() toggle the protocol flag."""
        port = find_unused_port()

        @asyncio.coroutine
        def go():
            s1 = yield from aiozmq.create_zmq_stream(
                zmq.DEALER,
                bind='tcp://127.0.0.1:{}'.format(port),
                loop=self.loop)

            self.assertFalse(s1._paused)
            s1._protocol.pause_writing()
            self.assertTrue(s1._protocol._paused)
            s1._protocol.resume_writing()
            self.assertFalse(s1._protocol._paused)
            s1.close()

        self.loop.run_until_complete(go())

    def test_resume_paused_with_drain(self):
        """resume_writing() releases a pending drain()."""
        port = find_unused_port()

        @asyncio.coroutine
        def go():
            s1 = yield from aiozmq.create_zmq_stream(
                zmq.DEALER,
                bind='tcp://127.0.0.1:{}'.format(port),
                loop=self.loop)

            self.assertFalse(s1._paused)
            s1._protocol.pause_writing()

            @asyncio.coroutine
            def f():
                yield from s1.drain()

            fut = ensure_future(f(), loop=self.loop)
            yield from asyncio.sleep(0.01, loop=self.loop)

            self.assertTrue(s1._protocol._paused)
            s1._protocol.resume_writing()
            self.assertFalse(s1._protocol._paused)

            yield from fut

            s1.close()

        self.loop.run_until_complete(go())

    def test_close_paused_connection(self):
        """A write-paused stream can be closed."""
        port = find_unused_port()

        @asyncio.coroutine
        def go():
            s1 = yield from aiozmq.create_zmq_stream(
                zmq.DEALER,
                bind='tcp://127.0.0.1:{}'.format(port),
                loop=self.loop)

            s1._protocol.pause_writing()
            s1.close()

        self.loop.run_until_complete(go())

    def test_close_paused_with_drain(self):
        """Closing a write-paused stream releases a pending drain()."""
        port = find_unused_port()

        @asyncio.coroutine
        def go():
            s1 = yield from aiozmq.create_zmq_stream(
                zmq.DEALER,
                bind='tcp://127.0.0.1:{}'.format(port),
                loop=self.loop)

            self.assertFalse(s1._paused)
            s1._protocol.pause_writing()

            @asyncio.coroutine
            def f():
                yield from s1.drain()

            fut = ensure_future(f(), loop=self.loop)
            yield from asyncio.sleep(0.01, loop=self.loop)

            s1.close()
            yield from fut

        self.loop.run_until_complete(go())

    def test_drain_after_closing(self):
        """drain() on a closed stream raises ConnectionResetError."""
        port = find_unused_port()

        @asyncio.coroutine
        def go():
            s1 = yield from aiozmq.create_zmq_stream(
                zmq.DEALER,
                bind='tcp://127.0.0.1:{}'.format(port),
                loop=self.loop)

            s1.close()
            yield from asyncio.sleep(0, loop=self.loop)

            with self.assertRaises(ConnectionResetError):
                yield from s1.drain()

        self.loop.run_until_complete(go())

    def test_exception_after_drain(self):
        """A pending drain() fails with the exception given to connection_lost."""
        port = find_unused_port()

        @asyncio.coroutine
        def go():
            s1 = yield from aiozmq.create_zmq_stream(
                zmq.DEALER,
                bind='tcp://127.0.0.1:{}'.format(port),
                loop=self.loop)

            self.assertFalse(s1._paused)
            s1._protocol.pause_writing()

            @asyncio.coroutine
            def f():
                yield from s1.drain()

            fut = ensure_future(f(), loop=self.loop)
            yield from asyncio.sleep(0.01, loop=self.loop)

            exc = RuntimeError("exception")
            s1._protocol.connection_lost(exc)
            with self.assertRaises(RuntimeError) as cm:
                yield from fut
            self.assertIs(cm.exception, exc)

        self.loop.run_until_complete(go())

    def test_double_read_of_closed_stream(self):
        """Repeated reads of a closed stream keep raising ZmqStreamClosed."""
        port = find_unused_port()

        @asyncio.coroutine
        def go():
            s2 = yield from aiozmq.create_zmq_stream(
                zmq.ROUTER,
                connect='tcp://127.0.0.1:{}'.format(port),
                loop=self.loop)

            self.assertFalse(s2.at_closing())
            s2.close()
            with self.assertRaises(aiozmq.ZmqStreamClosed):
                yield from s2.read()
            self.assertTrue(s2.at_closing())

            with self.assertRaises(aiozmq.ZmqStreamClosed):
                yield from s2.read()
            self.assertTrue(s2.at_closing())

        self.loop.run_until_complete(go())

    @unittest.skipIf(zmq.zmq_version_info() < (4, )
                     or zmq.pyzmq_version_info() < (14, 4),
                     "Socket monitor requires libzmq >= 4 and pyzmq >= 14.4")
    def test_monitor(self):
        """Monitor events on a stream transport are delivered by read_event()."""
        port = find_unused_port()

        @asyncio.coroutine
        def go():
            addr = 'tcp://127.0.0.1:{}'.format(port)
            s1 = yield from aiozmq.create_zmq_stream(zmq.ROUTER,
                                                     bind=addr,
                                                     loop=self.loop)

            @asyncio.coroutine
            def f(s, events):
                # Collect events until the stream is closed.
                try:
                    while True:
                        event = yield from s.read_event()
                        events.append(event)
                except aiozmq.ZmqStreamClosed:
                    pass

            s2 = yield from aiozmq.create_zmq_stream(zmq.DEALER,
                                                     loop=self.loop)

            events = []
            t = ensure_future(f(s2, events), loop=self.loop)

            yield from s2.transport.enable_monitor()
            yield from s2.transport.connect(addr)
            yield from s2.transport.disconnect(addr)
            yield from s2.transport.connect(addr)

            s2.write([b'request'])
            req = yield from s1.read()
            self.assertEqual([mock.ANY, b'request'], req)
            s1.write([req[0], b'answer'])
            answer = yield from s2.read()
            self.assertEqual([b'answer'], answer)

            s2.close()
            s1.close()

            yield from t

            # Confirm that the events received by the monitor were valid.
            self.assertGreater(len(events), 0)
            while len(events):
                event = events.pop()
                self.assertIsInstance(event, SocketEvent)
                self.assertIn(event.event, ZMQ_EVENTS)

        self.loop.run_until_complete(go())
Ejemplo n.º 21
0
    'ast',
    'difflib',
    'distutils',
    'distutils.version',
    'numbers',
    'json',
    'M2Crypto',
    'Cookie',
    'asyncore',
    'fileinput',
    'email',
    'email.mime.*',
]

# Frozen builds require ZMQ to be installed, so probing the PyZMQ version is
# safe here. Newer PyZMQ no longer needs 'zmq.core.*' when freezing, so drop
# it from the include list. A single flat condition replaces the original
# nested checks (which repeated the HAS_ZMQ guard); short-circuit order is
# unchanged.
# NOTE(review): PyZMQ versions seen elsewhere in this file are 2.x, 14.x and
# 17.x, all of which satisfy >= (0, 14), so this effectively always removes
# 'zmq.core.*'. The "0.14" presumably meant 14.0 -- confirm before tightening.
if (HAS_ZMQ and hasattr(zmq, 'pyzmq_version_info')
        and zmq.pyzmq_version_info() >= (0, 14)
        and 'zmq.core.*' in FREEZER_INCLUDES):
    FREEZER_INCLUDES.remove('zmq.core.*')

if IS_WINDOWS_PLATFORM:
    FREEZER_INCLUDES.extend([
        'win32api',
        'win32file',
        'win32con',
        'win32com',
        'win32net',
        'win32netcon',
        'win32gui',
Ejemplo n.º 22
0
from threading import Thread

import zmq

if zmq.pyzmq_version_info() < (17, 0):
    # Older pyzmq: use its bundled IOLoop (deprecated since pyzmq 17).
    from zmq.eventloop.ioloop import IOLoop
else:
    # From pyzmq 17 on, take the IOLoop straight from tornado.
    from tornado.ioloop import IOLoop


class ControlThread(Thread):
    def __init__(self, **kwargs):
        """Create the "Control" thread with its own, not-yet-current IOLoop.

        Extra keyword arguments are forwarded to ``threading.Thread``.
        """
        super().__init__(name="Control", **kwargs)
        self.io_loop = IOLoop(make_current=False)
        # Marker attributes; their names suggest the pydev debugger reads
        # them to skip tracing this helper thread -- confirm.
        self.pydev_do_not_trace = True
        self.is_pydev_daemon_thread = True

    def run(self):
        self.name = "Control"
        self.io_loop.make_current()
        try:
            self.io_loop.start()
        finally:
            self.io_loop.close()

    def stop(self):
        """Stop the thread.

        This method is threadsafe.
        """
Ejemplo n.º 23
0
import textwrap
import curses
import sys

import zmq

# Print the libzmq and pyzmq version tuples when this module is loaded.
print(zmq.zmq_version_info())
print(zmq.pyzmq_version_info())

class server(object):
    def __init__(self):
        self.subs = {}
        self.servers = {}

    def log_print(self, indent, msg):
        try:
            indent = indent.decode()
        except:
            pass
        self.textwrapper.initial_indent = '{:22}'.format(indent)
        self.textwrapper.subsequent_indent = ' '*max(22,len(indent))
        lines = self.textwrapper.wrap(msg)
        for l in lines:
            self.logwin.scroll()
            p = self.logwin.getmaxyx()
            self.logwin.addstr(p[0]-2,1,l)
        self.logwin.border()
        self.logwin.refresh()

    def setup_screen(self):
Ejemplo n.º 24
0
class ZmqSocketMonitorTests(unittest.TestCase):
    """Socket-monitor tests for low-level aiozmq connections."""

    def setUp(self):
        # Every test gets a private loop; none is installed as current.
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(None)

    def tearDown(self):
        self.loop.close()
        asyncio.set_event_loop(None)

    @unittest.skipIf(zmq.zmq_version_info() < (4, )
                     or zmq.pyzmq_version_info() < (14, 4),
                     "Socket monitor requires libzmq >= 4 and pyzmq >= 14.4")
    def test_socket_monitor(self):
        """Connect/disconnect activity produces valid monitor events."""
        port = find_unused_port()
        loop = self.loop

        @asyncio.coroutine
        def go():
            # Server side: a ROUTER bound to a fixed local port.
            server_tr, server_proto = yield from aiozmq.create_zmq_connection(
                lambda: Protocol(loop),
                zmq.ROUTER,
                bind='tcp://127.0.0.1:{}'.format(port),
                loop=loop)
            yield from server_proto.wait_ready
            endpoint = list(server_tr.bindings())[0]

            # Client side: a DEALER that is not connected yet.
            client_tr, client_proto = yield from aiozmq.create_zmq_connection(
                lambda: Protocol(loop), zmq.DEALER, loop=loop)
            yield from client_proto.wait_ready

            # Attach the monitor first so the connect/disconnect traffic
            # below is observed.
            yield from client_tr.enable_monitor()

            yield from client_tr.connect(endpoint)
            yield from asyncio.sleep(0.1, loop=loop)
            yield from client_tr.disconnect(endpoint)
            yield from asyncio.sleep(0.1, loop=loop)
            yield from client_tr.connect(endpoint)

            # The server replies to this message, which completes the
            # client's wait_done future.
            client_tr.write([b'Hello'])
            yield from client_proto.wait_done

            yield from client_tr.disable_monitor()

            client_tr.close()
            yield from client_proto.wait_closed
            server_tr.close()
            yield from server_proto.wait_closed

            # Every event the monitor delivered must be a known zmq event.
            received = client_proto.events_received
            self.assertGreater(received.qsize(), 0)
            while not received.empty():
                event = yield from received.get()
                self.assertIn(event.event, ZMQ_EVENTS)

        loop.run_until_complete(go())

    def test_unsupported_dependencies(self):
        """enable_monitor() refuses to run on too-old libzmq or pyzmq."""
        @asyncio.coroutine
        def go():
            client_tr, client_proto = yield from aiozmq.create_zmq_connection(
                lambda: Protocol(self.loop), zmq.DEALER, loop=self.loop)
            yield from client_proto.wait_ready

            # Fake an outdated libzmq, then an outdated pyzmq, one at a time.
            too_old = (('zmq_version_info', (3, )),
                       ('pyzmq_version_info', (14, 3)))
            for attr, fake_version in too_old:
                with unittest.mock.patch.object(zmq, attr,
                                                return_value=fake_version):
                    with self.assertRaises(NotImplementedError):
                        yield from client_tr.enable_monitor()

            client_tr.close()
            yield from client_proto.wait_closed

        self.loop.run_until_complete(go())

    @unittest.skipIf(zmq.zmq_version_info() < (4, )
                     or zmq.pyzmq_version_info() < (14, 4),
                     "Socket monitor requires libzmq >= 4 and pyzmq >= 14.4")
    def test_double_enable_disable(self):
        """Repeated enable_monitor()/disable_monitor() calls are harmless."""
        @asyncio.coroutine
        def go():
            client_tr, client_proto = yield from aiozmq.create_zmq_connection(
                lambda: Protocol(self.loop), zmq.DEALER, loop=self.loop)
            yield from client_proto.wait_ready

            # The second enable on an already-monitored socket must not fail.
            for _ in range(2):
                yield from client_tr.enable_monitor()

            # Likewise, disabling twice must not fail.
            for _ in range(2):
                yield from client_tr.disable_monitor()

            client_tr.close()
            yield from client_proto.wait_closed

        self.loop.run_until_complete(go())
Ejemplo n.º 25
0
    # The socket monitor can be explicitly disabled if necessary.
    # yield from ct.disable_monitor()

    # If a socket monitor is left enabled on a socket being closed,
    # the socket monitor will be closed automatically.
    ct.close()
    yield from cp.wait_closed

    st.close()
    yield from sp.wait_closed


def main():
    """Run the ``go`` demo coroutine on the default loop, then report."""
    loop = asyncio.get_event_loop()
    loop.run_until_complete(go())
    print("DONE")


if __name__ == '__main__':
    # Uncomment for DEBUG-level logging output:
    # import logging
    # logging.basicConfig(level=logging.DEBUG)

    # Socket monitoring needs libzmq >= 4 and pyzmq >= 14.4. pyzmq is only
    # queried once the libzmq check has passed.
    if not (zmq.zmq_version_info() >= (4,)
            and zmq.pyzmq_version_info() >= (14, 4)):
        raise NotImplementedError(
            "Socket monitor requires libzmq >= 4 and pyzmq >= 14.4, "
            "have libzmq:{}, pyzmq:{}".format(
                zmq.zmq_version(), zmq.pyzmq_version()))

    main()
Ejemplo n.º 26
0
            [mesg] = data
            msg = json.loads(mesg.decode('utf-8'))
            log.info('Dealer received: {}'.format(msg))
            self.transmit(msg)
        except asyncio.QueueEmpty:
            pass

    def transmit(self, msg):
        """Send ``msg`` as one UTF-8 JSON frame over the transport.

        Logs an error instead of sending when no usable transport is set.
        """
        if not self.transport:
            log.error('Invalid zmq transport.')
            return
        payload = json.dumps(msg).encode('utf-8')
        self.transport.write([payload])


if __name__ == '__main__':
    if (zmq.zmq_version_info() < (4, ) or zmq.pyzmq_version_info() < (
            14,
            4,
    )):
        raise NotImplementedError(
            "Socket monitor requires libzmq >= 4 and pyzmq >= 14.4, "
            "have libzmq:{}, pyzmq:{}".format(zmq.zmq_version(),
                                              zmq.pyzmq_version()))

    log = logging.getLogger("")
    formatter = logging.Formatter("%(asctime)s %(levelname)s " +
                                  "[%(module)s:%(lineno)d] %(message)s")
    # log the things
    log.setLevel(logging.DEBUG)
    ch = logging.StreamHandler()
    ch.setLevel(logging.DEBUG)
Ejemplo n.º 27
0
from random import randint
import time
import json
import pprint
import encodings

import zmq
from zmq.devices import monitored_queue
from zmq.utils.monitor import recv_monitor_message

# Print the libzmq and pyzmq version tuples when this module is loaded.
print(zmq.zmq_version_info())
print(zmq.pyzmq_version_info())

# Reverse lookup from each numeric ``zmq.EVENT_*`` value to its constant
# name; every mapping is also printed. dir() is sorted, so if two constants
# share a value, the alphabetically later name wins.
EVENT_MAP = {}
print("Event names:")
for name in dir(zmq):
    if name.startswith('EVENT_'):
        value = getattr(zmq, name)
        print("%21s : %4i" % (name, value))
        EVENT_MAP[value] = name

class server(object):
    def __init__(self, bind, identity):
        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.ROUTER)
        self.socket.identity = identity
        self.socket.bind(bind)
        self.subs = {}
        self.servers = {}
        self.fds = {}
        self.monitor = self.socket.get_monitor_socket()
Ejemplo n.º 28
0
    'ast',
    'difflib',
    'distutils',
    'distutils.version',
    'numbers',
    'json',
    'M2Crypto',
    'Cookie',
    'asyncore',
    'fileinput',
    'email',
    'email.mime.*',
]

# Frozen builds require ZMQ to be installed, so probing the PyZMQ version is
# safe here. Newer PyZMQ no longer needs 'zmq.core.*' when freezing, so drop
# it from the include list. A single flat condition replaces the original
# nested checks (which repeated the HAS_ZMQ guard); short-circuit order is
# unchanged.
# NOTE(review): PyZMQ versions seen elsewhere in this file are 2.x, 14.x and
# 17.x, all of which satisfy >= (0, 14), so this effectively always removes
# 'zmq.core.*'. The "0.14" presumably meant 14.0 -- confirm before tightening.
if (HAS_ZMQ and hasattr(zmq, 'pyzmq_version_info')
        and zmq.pyzmq_version_info() >= (0, 14)
        and 'zmq.core.*' in FREEZER_INCLUDES):
    FREEZER_INCLUDES.remove('zmq.core.*')

if IS_WINDOWS_PLATFORM:
    FREEZER_INCLUDES.extend([
        'win32api',
        'win32file',
        'win32con',
        'win32com',
        'win32net',
        'win32netcon',
        'win32gui',
Ejemplo n.º 29
0
    await cp.wait_done

    # The socket monitor can be explicitly disabled if necessary.
    # await ct.disable_monitor()

    # If a socket monitor is left enabled on a socket being closed,
    # the socket monitor will be closed automatically.
    ct.close()
    await cp.wait_closed

    st.close()
    await sp.wait_closed


def main():
    """Drive the ``go`` demo coroutine via asyncio.run(), then report."""
    asyncio.run(go())
    print("DONE")


if __name__ == "__main__":
    # Uncomment for DEBUG-level logging output:
    # import logging
    # logging.basicConfig(level=logging.DEBUG)

    # Socket monitoring needs libzmq >= 4 and pyzmq >= 14.4. pyzmq is only
    # queried once the libzmq check has passed.
    if not (zmq.zmq_version_info() >= (4, )
            and zmq.pyzmq_version_info() >= (14, 4)):
        raise NotImplementedError(
            "Socket monitor requires libzmq >= 4 and pyzmq >= 14.4, "
            "have libzmq:{}, pyzmq:{}".format(zmq.zmq_version(),
                                              zmq.pyzmq_version()))

    main()