예제 #1
0
파일: render.py 프로젝트: deev/universe
 def __init__(self, *args, **kwargs):
     """
     Record whether on-screen rendering is possible, then defer to the
     parent initializer.

     Rendering is considered unavailable only on Linux when the ``DISPLAY``
     environment variable is unset or empty.
     """
     headless_linux = platform.isLinux() and not os.environ.get('DISPLAY')
     self.renderable = not headless_linux
     self._observation = None
     super(Render, self).__init__(*args, **kwargs)
예제 #2
0
파일: render.py 프로젝트: zmoon111/universe
 def __init__(self, *args, **kwargs):
     """
     Initialise rendering state and chain to the parent initializer.

     ``self.renderable`` starts out true and is cleared only on Linux hosts
     where ``DISPLAY`` is unset or empty.
     """
     self.renderable = True
     if platform.isLinux():
         if not os.environ.get('DISPLAY'):
             # No display to draw on.
             self.renderable = False
     self._observation = None
     super(Render, self).__init__(*args, **kwargs)
예제 #3
0
def _getInstallFunction(platform):
    """
    Return a function to install the reactor most suited for the given platform.

    Selection rules:

      - Linux: the epoll(7) reactor, or the poll(2) reactor if the epoll
        reactor module cannot be imported.
      - Other POSIX platforms except macOS: the poll(2) reactor.
      - Everything else (macOS, Windows, ...): the select(2) reactor.

    Whenever importing a preferred reactor module raises L{ImportError}, the
    select(2) reactor is used instead.

    @param platform: The platform for which to select a reactor.
    @type platform: L{twisted.python.runtime.Platform}

    @return: A zero-argument callable which will install the selected
        reactor.
    """
    # epoll(7) is preferred on Linux since it scales well.
    #
    # macOS: poll(2) is not exposed by Python because it doesn't support all
    # file descriptors (in particular, lack of PTY support is a problem) --
    # see <http://bugs.python.org/issue5154>. kqueue has the same restrictions
    # as poll(2) as far PTY support goes.
    #
    # Windows: IOCP should eventually be default, but still has some serious
    # bugs, e.g. <http://twistedmatrix.com/trac/ticket/4667>.
    try:
        if platform.isLinux():
            try:
                from twisted.internet.epollreactor import install
            except ImportError:
                from twisted.internet.pollreactor import install
            return install
        if platform.getType() == 'posix' and not platform.isMacOSX():
            from twisted.internet.pollreactor import install
            return install
    except ImportError:
        # The preferred reactor is unavailable: fall back to select(2).
        pass
    from twisted.internet.selectreactor import install
    return install
예제 #4
0
def _machine_id():
    """
    For informational purposes, try to get a machine unique id thing.

    - Linux: the contents of ``/var/lib/dbus/machine-id``, or the hostname
      when that file cannot be read.
    - macOS: the ``IOPlatformSerialNumber`` reported by ``ioreg``.
    - Anything else: the hostname.

    :returns: The identifier as a string.
    """
    if platform.isLinux():
        try:
            # why this? see: http://0pointer.de/blog/projects/ids.html
            with open('/var/lib/dbus/machine-id', 'r') as f:
                return f.read().strip()
        except (IOError, OSError, ValueError):
            # Non-dbus using Linux (file missing, unreadable or undecodable):
            # get a hostname.  Narrow on purpose so KeyboardInterrupt and
            # SystemExit are not swallowed the way a bare ``except:`` would.
            return socket.gethostname()

    elif platform.isMacOSX():
        # Get the serial number of the platform
        import plistlib
        plist_data = subprocess.check_output(
            ["ioreg", "-rd1", "-c", "IOPlatformExpertDevice", "-a"])

        if six.PY2:
            # Only API on 2.7
            entries = plistlib.readPlistFromString(plist_data)
        else:
            # New, non-deprecated 3.4+ API
            entries = plistlib.loads(plist_data)
        # ``ioreg -a`` yields a plist array, so both APIs must be indexed
        # with [0] before looking up the serial (the 2.7 path used to skip
        # this and failed with TypeError).
        return entries[0]["IOPlatformSerialNumber"]

    else:
        # Something else, just get a hostname
        return socket.gethostname()
예제 #5
0
class UNIXDatagramTestsBuilder(UNIXFamilyMixin, ReactorBuilder):
    """
    Reactor-builder tests for L{IReactorUNIXDatagram}.
    """
    requiredInterfaces = (interfaces.IReactorUNIXDatagram,)

    # No test_connectMode counterpart exists: connectUNIXDatagram has ignored
    # its mode parameter ever since that API was introduced.
    def test_listenMode(self):
        """
        L{IReactorUNIXDatagram.listenUNIXDatagram} creates its UNIX socket
        with the requested mode.
        """
        socketPath = self.mktemp()
        self._modeTest('listenUNIXDatagram', socketPath, DatagramProtocol())

    def test_listenOnLinuxAbstractNamespace(self):
        """
        On Linux a UNIX socket path that begins with a NUL byte names a
        socket in the abstract namespace;
        L{IReactorUNIXDatagram.listenUNIXDatagram} accepts such a path.
        """
        abstractName = '\0' + _abstractPath(self)
        listeningPort = self.buildReactor().listenUNIXDatagram(
            abstractName, DatagramProtocol())
        self.assertEqual(listeningPort.getHost(), UNIXAddress(abstractName))

    if not platform.isLinux():
        test_listenOnLinuxAbstractNamespace.skip = (
            'Abstract namespace UNIX sockets only supported on Linux.')
예제 #6
0
파일: default.py 프로젝트: 0004c/VTK
def _getInstallFunction(platform):
    """
    Return a function to install the reactor most suited for the given platform.

    Candidates are tried in order and the first one whose module imports
    cleanly wins; the select(2) reactor is the last resort:

      - Linux: epoll(7), then poll(2).
      - Other POSIX platforms except OS X: poll(2).
      - Everything else: select(2) only.

    @param platform: The platform for which to select a reactor.
    @type platform: L{twisted.python.runtime.Platform}

    @return: A zero-argument callable which will install the selected
        reactor.
    """
    # OS X: poll(2) is not exposed by Python because it doesn't support all
    # file descriptors (in particular, lack of PTY support is a problem) --
    # see <http://bugs.python.org/issue5154>. kqueue has the same restrictions
    # as poll(2) as far PTY support goes.
    #
    # Windows: IOCP should eventually be default, but still has some serious
    # bugs, e.g. <http://twistedmatrix.com/trac/ticket/4667>.
    def _select():
        from twisted.internet.selectreactor import install
        return install

    def _poll():
        from twisted.internet.pollreactor import install
        return install

    def _epoll():
        from twisted.internet.epollreactor import install
        return install

    if platform.isLinux():
        attempts = [_epoll, _poll]
    elif platform.getType() == 'posix' and not platform.isMacOSX():
        attempts = [_poll]
    else:
        attempts = []

    for attempt in attempts:
        try:
            return attempt()
        except ImportError:
            continue
    return _select()
예제 #7
0
class AgentServiceLoopTests(TestCase):
    """
    Tests for ``AgentService.get_loop_service``.
    """
    def setUp(self):
        super(AgentServiceLoopTests, self).setUp()
        agent_service_setup(self)

    @skipUnless(platform.isLinux(), "get_era() only supports Linux.")
    def test_agentloopservice(self):
        """
        ``AgentService.get_loop_service`` returns an ``AgentLoopService``
        using the given deployer and all of the configuration supplied to
        the ``AgentService``.
        """
        service = self.agent_service
        deployer = object()
        actual = service.get_loop_service(deployer)
        expected = AgentLoopService(
            reactor=self.reactor,
            deployer=deployer,
            host=self.host,
            port=self.port,
            context_factory=service.get_tls_context().context_factory,
            era=get_era(),
        )
        self.assertEqual(expected, actual)
예제 #8
0
파일: node.py 프로젝트: NinjaMSP/crossbar
def _machine_id():
    """
    For informational purposes, try to get a machine unique id thing.

    - Linux: contents of ``/var/lib/dbus/machine-id``, or the hostname when
      that file cannot be read.
    - macOS: the ``IOPlatformSerialNumber`` reported by ``ioreg``.
    - Anything else: the hostname.
    """
    if platform.isLinux():
        try:
            # why this? see: http://0pointer.de/blog/projects/ids.html
            with open('/var/lib/dbus/machine-id', 'r') as f:
                return f.read().strip()
        except (IOError, OSError, ValueError):
            # Non-dbus using Linux (file missing, unreadable or undecodable):
            # get a hostname.  Narrow on purpose so KeyboardInterrupt and
            # SystemExit are not swallowed the way a bare ``except:`` would.
            return socket.gethostname()

    elif platform.isMacOSX():
        # Get the serial number of the platform
        import plistlib
        plist_data = subprocess.check_output(["ioreg", "-rd1", "-c", "IOPlatformExpertDevice", "-a"])

        if six.PY2:
            # Only API on 2.7
            return plistlib.readPlistFromString(plist_data)[0]["IOPlatformSerialNumber"]
        else:
            # New, non-deprecated 3.4+ API
            return plistlib.loads(plist_data)[0]["IOPlatformSerialNumber"]

    else:
        # Something else, just get a hostname
        return socket.gethostname()
예제 #9
0
def getreciever(addr):
    """
    Return the preferred UDP receiver for ``addr`` on this platform.

    - Linux: the epoll-based receiver, or the poll-based one if creating the
      epoll receiver raises.
    - macOS: the kqueue-based receiver, with the same poll fallback.
    - Any other platform (POSIX or not): the poll-based receiver.
    """
    if platform.isLinux():
        try:
            return epolludprecieve(addr)
        except Exception:
            # Best-effort fallback; ``Exception`` (unlike a bare ``except:``)
            # lets KeyboardInterrupt/SystemExit propagate.
            return polludprecieve(addr)

    elif platform.isMacOSX():
        try:
            return kqueueudprecieve(addr)
        except Exception:
            return polludprecieve(addr)

    # Non-macOS POSIX systems and everything else both use poll.
    return polludprecieve(addr)
예제 #10
0
 def __eq__(self, other):
     """
     Overriding C{FancyEqMixin} to ensure the os level samefile
     check is done if the name attributes do not match.

     @return: The parent comparison's result, or the result of
         C{os.path.samefile} when the names differ but both are set.
         C{NotImplemented} from the parent is passed straight through.
     """
     res = super(UNIXAddress, self).__eq__(other)
     if res is NotImplemented:
         # Return it untouched: truth-testing NotImplemented is deprecated
         # (an error on newer Pythons), and C{other} may lack C{name}.
         return res
     if not res and self.name and other.name:
         try:
             return os.path.samefile(self.name, other.name)
         except OSError:
             pass
         except (TypeError, ValueError):
             # On Linux, abstract namespace UNIX sockets start with a
             # \0, which os.path doesn't like.
             if not _PY3 and not platform.isLinux():
                 raise
     return res
예제 #11
0
파일: address.py 프로젝트: zerospam/twisted
 def __eq__(self, other: object) -> bool:
     """
     Compare by C{name}, falling back to an OS-level C{samefile} check.

     Overrides the C{attrs}-generated equality: two addresses with different
     names are still equal when both names refer to the same file.

     @return: C{NotImplemented} for objects of another class; otherwise
         whether the two addresses are equal.
     """
     if not isinstance(other, self.__class__):
         return NotImplemented
     same_name = self.name == other.name
     if same_name or not (self.name and other.name):
         return same_name
     try:
         return os.path.samefile(self.name, other.name)
     except OSError:
         return same_name
     except (TypeError, ValueError):
         # Linux abstract-namespace socket names begin with \0, which
         # os.path rejects; that is only tolerated on Linux.
         if not platform.isLinux():
             raise
         return same_name
예제 #12
0
class FlockerNodeEraTests(make_script_tests(EXECUTABLE)):
    """
    Tests for the ``flocker-node-era`` script.
    """
    @skipUnless(which(EXECUTABLE), EXECUTABLE + " not installed")
    @skipUnless(platform.isLinux(), "flocker-node-era only works on Linux")
    def setUp(self):
        super(FlockerNodeEraTests, self).setUp()

    def test_output(self):
        """
        The script prints the same information that ``get_era()`` returns.
        """
        printed = check_output(EXECUTABLE)
        self.assertEqual(printed, str(get_era()))

    def test_repeatable_output(self):
        """
        Running the script twice prints the same information both times,
        since the era should only change on reboot.
        """
        first = check_output(EXECUTABLE)
        second = check_output(EXECUTABLE)
        self.assertEqual(first, second)
예제 #13
0
def _get_reactor(platform):
    """
    Instantiate the reactor class best suited to C{platform}.

    Linux prefers epoll (falling back to poll), macOS uses kqueue, other
    POSIX systems use poll, and everything else uses select.  If importing
    the preferred reactor module raises L{ImportError}, a select reactor is
    used instead.

    @param platform: An object exposing ``isLinux()``, ``isMacOSX()`` and
        ``getType()``.
    @return: A new reactor instance (this function does not install it).
    """
    def _pick_class():
        try:
            if platform.isLinux():
                try:
                    from twisted.internet import epollreactor
                    return epollreactor.EPollReactor
                except ImportError:
                    from twisted.internet import pollreactor
                    return pollreactor.PollReactor
            if platform.isMacOSX():
                from twisted.internet import kqreactor
                return kqreactor.KQueueReactor
            if platform.getType() == 'posix':
                from twisted.internet import pollreactor
                return pollreactor.PollReactor
        except ImportError:
            pass
        from twisted.internet import selectreactor
        return selectreactor.SelectReactor

    reactor_class = _pick_class()
    return reactor_class()
예제 #14
0
파일: twisty.py 프로젝트: deev/universe
def _get_reactor(platform):
    """
    Build a fresh reactor of the class preferred for C{platform}.

    Preference: epoll (then poll) on Linux, kqueue on macOS, poll on other
    POSIX systems, select elsewhere.  An L{ImportError} while loading the
    preferred module makes this fall back to select.  The returned reactor
    is only constructed, not installed.
    """
    try:
        if platform.isLinux():
            try:
                from twisted.internet import epollreactor as chosen
                factory = chosen.EPollReactor
            except ImportError:
                from twisted.internet import pollreactor as chosen
                factory = chosen.PollReactor
        elif platform.isMacOSX():
            from twisted.internet import kqreactor as chosen
            factory = chosen.KQueueReactor
        elif platform.getType() != 'posix':
            from twisted.internet import selectreactor as chosen
            factory = chosen.SelectReactor
        else:
            from twisted.internet import pollreactor as chosen
            factory = chosen.PollReactor
    except ImportError:
        from twisted.internet import selectreactor as chosen
        factory = chosen.SelectReactor
    return factory()
예제 #15
0
def machine_id() -> str:
    """
    For informational purposes, get a unique ID or serial for this machine (device).

    - Linux: contents of ``/var/lib/dbus/machine-id``; the hostname if that
      file cannot be read.
    - macOS: ``IOPlatformSerialNumber`` as reported by ``ioreg``.
    - Other platforms: the hostname.

    :returns: Unique machine (device) ID (serial), e.g. ``81655b901e334fc1ad59cbf2719806b7``.
    """
    from twisted.python.runtime import platform

    if platform.isLinux():
        try:
            # why this? see: http://0pointer.de/blog/projects/ids.html
            with open('/var/lib/dbus/machine-id', 'r') as f:
                return f.read().strip()
        except (OSError, ValueError):
            # Non-dbus using Linux (file missing, unreadable or undecodable):
            # get a hostname.  Narrow on purpose so KeyboardInterrupt and
            # SystemExit are not swallowed the way a bare ``except:`` would.
            return socket.gethostname()
    elif platform.isMacOSX():
        import plistlib
        plist_data = subprocess.check_output(
            ["ioreg", "-rd1", "-c", "IOPlatformExpertDevice", "-a"])
        # The parsed plist is indexed as a list; the first entry's serial is used.
        return plistlib.loads(plist_data)[0]["IOPlatformSerialNumber"]
    else:
        return socket.gethostname()
예제 #16
0
파일: test_era.py 프로젝트: gooops/flocker
class EraTests(SynchronousTestCase):
    """
    Tests for ``get_era``.
    """
    @skipUnless(platform.isLinux(), "get_era() only supported on Linux.")
    def setUp(self):
        pass

    def test_get_era(self):
        """
        ``get_era()`` equals the kernel's current ``boot_id``.

        This largely restates the implementation, but there is little else
        to compare against.
        """
        with open("/proc/sys/kernel/random/boot_id") as boot_id_file:
            boot_id = boot_id_file.read().strip()
            self.assertEqual(get_era(), UUID(hex=boot_id))

    def test_repeated(self):
        """
        Calling ``get_era()`` many times always yields one and the same value.
        """
        distinct = {get_era() for _ in range(100)}
        self.assertEqual(len(distinct), 1)
예제 #17
0
    nonUNIXSkip = True
else:
    nonUNIXSkip = False

from unittest import skipIf

from twisted.internet import reactor
from twisted.internet.defer import Deferred, inlineCallbacks
from twisted.internet.error import ProcessDone
from twisted.internet.protocol import ProcessProtocol
from twisted.python.filepath import FilePath
from twisted.python.runtime import platform

from twisted.trial.unittest import TestCase

# Decide whether the MSG_DONTWAIT-based tests can run: the flag is imported
# only on Linux, and dontWaitSkip is presumably consulted by those tests as
# their skip flag (False = run, True = skip).
if platform.isLinux():
    from socket import MSG_DONTWAIT

    dontWaitSkip = False
else:
    # It would be nice to be able to test flags on more platforms, but finding
    # a flag that works *at all* is somewhat challenging.
    dontWaitSkip = True

try:
    from twisted.python.sendmsg import sendmsg, recvmsg
    from twisted.python.sendmsg import SCM_RIGHTS, getSocketFamily
except ImportError:
    doImportSkip = True
    importSkipReason = "Platform doesn't support sendmsg."
else:
예제 #18
0
class AgentServiceFactoryTests(TestCase):
    """
    Tests for ``AgentServiceFactory``.
    """
    def setUp(self):
        super(AgentServiceFactoryTests, self).setUp()
        setup_config(self)

    def service_factory(self, deployer_factory):
        """
        Build an ``AgentServiceFactory`` for unit tests.

        The external-IP lookup is replaced by a stub that always answers
        ``u"127.0.0.1"``.

        :param deployer_factory: ``deployer_factory`` to use.

        :return: ``AgentServiceFactory`` instance.
        """
        def stub_external_ip(host, port):
            return u"127.0.0.1"

        return AgentServiceFactory(
            deployer_factory=deployer_factory,
            get_external_ip=stub_external_ip,
        )

    def _parsed_options(self, config_path=None):
        """
        Build ``DatasetAgentOptions`` parsed with ``--agent-config``.

        :param config_path: Configuration file path to pass; defaults to
            ``self.config.path``.

        :return: The parsed ``DatasetAgentOptions``.
        """
        if config_path is None:
            config_path = self.config.path
        parsed = DatasetAgentOptions()
        parsed.parseOptions([b"--agent-config", config_path])
        return parsed

    def _assert_loop_service(self, host, port):
        """
        Assert that ``get_service`` builds an ``AgentLoopService`` pointed at
        ``host``/``port`` with TLS credentials from the config directory.
        """
        reactor = MemoryCoreReactor()
        options = self._parsed_options()
        factory = self.service_factory(
            deployer_factory=deployer_factory_stub)
        credential = _context_factory_and_credential(
            self.config.parent(), host, port)
        expected = AgentLoopService(
            reactor=reactor,
            deployer=deployer,
            host=host,
            port=port,
            context_factory=credential.context_factory,
            era=get_era(),
        )
        self.assertEqual(expected, factory.get_service(reactor, options))

    def test_uuids_from_certificate(self):
        """
        The deployer is created with the node UUID and cluster UUID found in
        the node certificate.
        """
        recorded = []

        def factory(hostname, node_uuid, cluster_uuid):
            recorded.append((node_uuid, cluster_uuid))
            return object()

        options = self._parsed_options()
        service_factory = self.service_factory(deployer_factory=factory)
        service_factory.get_service(MemoryCoreReactor(), options)
        node = self.ca_set.node
        self.assertEqual((node.uuid, node.cluster_uuid), recorded[0])

    @skipUnless(platform.isLinux(), "get_era() only supports Linux.")
    def test_get_service(self):
        """
        ``AgentServiceFactory.get_service`` creates an ``AgentLoopService``
        configured with the destination from the config file named by the
        options.
        """
        self._assert_loop_service(b"10.0.0.1", 1234)

    @skipUnless(platform.isLinux(), "get_era() only supports Linux.")
    def test_default_port(self):
        """
        When the config file names no port, ``get_service`` configures the
        ``AgentLoopService`` with port 4524.
        """
        config = {
            u"control-service": {u"hostname": u"10.0.0.2"},
            u"dataset": {u"backend": u"zfs"},
            u"version": 1,
        }
        self.config.setContent(yaml.safe_dump(config))
        self._assert_loop_service(b"10.0.0.2", 4524)

    def test_config_validated(self):
        """
        ``AgentServiceFactory.get_service`` rejects an invalid configuration
        file with ``ValidationError``.
        """
        self.config.setContent("INVALID")
        reactor = MemoryCoreReactor()
        options = self._parsed_options()
        factory = self.service_factory(
            deployer_factory=deployer_factory_stub)
        self.assertRaises(
            ValidationError, factory.get_service, reactor, options)

    def test_deployer_factory_called_with_ip(self):
        """
        ``AgentServiceFactory.get_service`` calls its ``deployer_factory``
        with one of the node's IPs.
        """
        seen_ips = []

        def recording_factory(node_uuid, hostname, cluster_uuid):
            seen_ips.append(IPAddress(hostname))
            return object()

        reactor = MemoryCoreReactor()
        options = self._parsed_options()
        agent = self.service_factory(deployer_factory=recording_factory)
        agent.get_service(reactor, options)
        self.assertIn(seen_ips[0], get_all_ips())

    def test_missing_configuration_file(self):
        """
        ``AgentServiceFactory.get_service`` raises ``IOError`` when the
        configuration file does not exist.
        """
        reactor = MemoryCoreReactor()
        options = self._parsed_options(self.make_temporary_path().path)
        factory = self.service_factory(
            deployer_factory=deployer_factory_stub)
        self.assertRaises(IOError, factory.get_service, reactor, options)
예제 #19
0
파일: _era.py 프로젝트: sysuwbs/flocker
 def main(self, reactor, options):
     """
     Write this node's era (``get_era()``) to standard output.

     Raises C{SystemExit} with an explanatory message on non-Linux platforms.

     :return: ``succeed(None)``.
     """
     if platform.isLinux():
         sys.stdout.write(str(get_era()))
         sys.stdout.flush()
         return succeed(None)
     raise SystemExit("flocker-node-era only works on Linux.")
예제 #20
0
 def main(self, reactor, options):
     """
     Print the era reported by ``get_era()`` and flush standard output.

     Only Linux is supported; elsewhere C{SystemExit} is raised.

     :return: ``succeed(None)``.
     """
     if not platform.isLinux():
         raise SystemExit("flocker-node-era only works on Linux.")
     era_text = str(get_era())
     stdout = sys.stdout
     stdout.write(era_text)
     stdout.flush()
     return succeed(None)
except ImportError:
    nonUNIXSkip = "Platform does not support AF_UNIX sockets"
else:
    nonUNIXSkip = None

from twisted.internet import reactor
from twisted.internet.defer import Deferred, inlineCallbacks
from twisted.internet.error import ProcessDone
from twisted.internet.protocol import ProcessProtocol
from twisted.python.compat import _PY3, intToBytes, bytesEnviron
from twisted.python.filepath import FilePath
from twisted.python.runtime import platform

from twisted.trial.unittest import TestCase

# MSG_DONTWAIT is imported only on Linux.  dontWaitSkip is None when the
# tests relying on it may run, otherwise it holds the skip reason string.
if platform.isLinux():
    from socket import MSG_DONTWAIT
    dontWaitSkip = None
else:
    # It would be nice to be able to test flags on more platforms, but finding
    # a flag that works *at all* is somewhat challenging.
    dontWaitSkip = "MSG_DONTWAIT is only known to work as intended on Linux"


# Probe for twisted.python.sendmsg.  importSkip is None when the import
# succeeded, otherwise it holds the reason string used to skip dependent tests.
try:
    from twisted.python.sendmsg import sendmsg, recvmsg
    from twisted.python.sendmsg import SCM_RIGHTS, getSocketFamily
except ImportError:
    importSkip = "Platform doesn't support sendmsg."
else:
    importSkip = None
예제 #22
0
class UNIXTestsBuilder(UNIXFamilyMixin, ReactorBuilder, ConnectionTestsMixin):
    """
    Builder defining tests relating to L{IReactorUNIX}.
    """
    requiredInterfaces = (IReactorUNIX, )

    endpoints = UNIXCreator()

    def test_mode(self):
        """
        L{IReactorUNIX.listenUNIX} creates its UNIX socket with the requested
        mode.
        """
        socketPath = self.mktemp()
        self._modeTest('listenUNIX', socketPath, ServerFactory())

    def test_listenOnLinuxAbstractNamespace(self):
        """
        On Linux a UNIX socket path beginning with a NUL byte names a socket
        in the abstract namespace; L{IReactorUNIX.listenUNIX} accepts such a
        path.
        """
        # Use _abstractPath so the name stays within the maximum allowed
        # length.
        abstractName = '\0' + _abstractPath(self)
        port = self.buildReactor().listenUNIX(abstractName, ServerFactory())
        self.assertEqual(port.getHost(), UNIXAddress(abstractName))

    if not platform.isLinux():
        test_listenOnLinuxAbstractNamespace.skip = (
            'Abstract namespace UNIX sockets only supported on Linux.')

    def test_listenFailure(self):
        """
        When the port's C{createInternetSocket} raises a socket error,
        L{IReactorUNIX.listenUNIX} raises L{CannotListenError}.
        """
        def failingCreate(port):
            raise error('FakeBasePort forced socket.error')

        self.patch(base.BasePort, "createInternetSocket", failingCreate)
        reactor = self.buildReactor()
        self.assertRaises(
            CannotListenError,
            reactor.listenUNIX, 'not-used', ServerFactory())

    def test_connectToLinuxAbstractNamespace(self):
        """
        L{IReactorUNIX.connectUNIX} accepts a Linux abstract namespace path
        as well.
        """
        abstractName = '\0' + _abstractPath(self)
        connector = self.buildReactor().connectUNIX(
            abstractName, ClientFactory())
        self.assertEqual(
            connector.getDestination(), UNIXAddress(abstractName))

    if not platform.isLinux():
        test_connectToLinuxAbstractNamespace.skip = (
            'Abstract namespace UNIX sockets only supported on Linux.')

    def test_addresses(self):
        """
        The C{getHost} and C{getPeer} of each side's transport return
        L{UNIXAddress} instances that mirror each other: one side's host is
        the other side's peer.
        """
        class SaveAddress(ConnectableProtocol):
            def makeConnection(self, transport):
                self.addresses = {
                    'host': transport.getHost(),
                    'peer': transport.getPeer(),
                }
                transport.loseConnection()

        server, client = SaveAddress(), SaveAddress()
        runProtocolsWithReactor(self, server, client, self.endpoints)

        serverSide, clientSide = server.addresses, client.addresses
        self.assertEqual(serverSide['host'], clientSide['peer'])
        self.assertEqual(serverSide['peer'], clientSide['host'])

    def test_sendFileDescriptor(self):
        """
        L{IUNIXTransport.sendFileDescriptor} accepts an integer file descriptor
        and sends a copy of it to the process reading from the connection.
        """
        from socket import fromfd

        s = socket()
        # Nothing else closes the socket whose descriptor is sent; release it
        # when the test finishes.
        self.addCleanup(s.close)
        s.bind(('', 0))
        server = SendFileDescriptor(s.fileno(), b"junk")

        client = ReceiveFileDescriptor()
        d = client.waitForDescriptor()

        def checkDescriptor(descriptor):
            received = fromfd(descriptor, AF_INET, SOCK_STREAM)
            # Thanks for the free dup, fromfd()
            close(descriptor)
            try:
                # If the sockets have the same local address, they're
                # probably the same.
                self.assertEqual(s.getsockname(), received.getsockname())

                # But it would be cheating for them to be identified by the
                # same file descriptor.  The point was to get a copy, as we
                # might get if there were two processes involved here.
                self.assertNotEqual(s.fileno(), received.fileno())
            finally:
                # fromfd() produced its own descriptor; release it as well.
                received.close()

        d.addCallback(checkDescriptor)
        d.addErrback(err, "Sending file descriptor encountered a problem")
        d.addBoth(lambda ignored: server.transport.loseConnection())

        runProtocolsWithReactor(self, server, client, self.endpoints)

    if sendmsgSkip is not None:
        test_sendFileDescriptor.skip = sendmsgSkip

    def test_sendFileDescriptorTriggersPauseProducing(self):
        """
        If a L{IUNIXTransport.sendFileDescriptor} call fills up the send buffer,
        any registered producer is paused.
        """
        # Client side: immediately stops reading from its transport, so data
        # sent by the server accumulates instead of being consumed.
        class DoesNotRead(ConnectableProtocol):
            def connectionMade(self):
                self.transport.pauseProducing()

        # Server side: registers itself as a streaming producer on its own
        # transport and keeps sending a descriptor plus one byte from a
        # LoopingCall until one of the producer callbacks fires.
        class SendsManyFileDescriptors(ConnectableProtocol):
            paused = False

            def connectionMade(self):
                # NOTE(review): this socket is never closed by the test;
                # consider closing it in _disconnect.
                self.socket = socket()
                self.transport.registerProducer(self, True)

                def sender():
                    self.transport.sendFileDescriptor(self.socket.fileno())
                    self.transport.write(b"x")

                self.task = LoopingCall(sender)
                # Drive the loop with the reactor under test.
                self.task.clock = self.transport.reactor
                self.task.start(0).addErrback(err, "Send loop failure")

            # Any producer callback other than pauseProducing also ends the
            # test, just without recording that a pause happened.
            def stopProducing(self):
                self._disconnect()

            def resumeProducing(self):
                self._disconnect()

            def pauseProducing(self):
                # The event under test: record it, then tear everything down.
                self.paused = True
                self.transport.unregisterProducer()
                self._disconnect()

            def _disconnect(self):
                # Stop sending and abort both ends of the connection.
                self.task.stop()
                self.transport.abortConnection()
                self.other.transport.abortConnection()

        server = SendsManyFileDescriptors()
        client = DoesNotRead()
        # Lets the server's _disconnect abort the client's connection too.
        server.other = client
        runProtocolsWithReactor(self, server, client, self.endpoints)

        self.assertTrue(server.paused,
                        "sendFileDescriptor producer was not paused")

    if sendmsgSkip is not None:
        test_sendFileDescriptorTriggersPauseProducing.skip = sendmsgSkip

    def test_fileDescriptorOverrun(self):
        """
        If L{IUNIXTransport.sendFileDescriptor} is used to queue a greater
        number of file descriptors than the number of bytes sent using
        L{ITransport.write}, the connection is closed and the protocol connected
        to the transport has its C{connectionLost} method called with a failure
        wrapping L{FileDescriptorOverrun}.
        """
        cargo = socket()
        # The socket is only needed for its descriptor; close it afterwards
        # rather than leaking it until garbage collection.
        self.addCleanup(cargo.close)
        # No data accompanies the descriptor, which triggers the overrun.
        server = SendFileDescriptor(cargo.fileno(), None)

        client = ReceiveFileDescriptor()
        result = []
        d = client.waitForDescriptor()
        d.addBoth(result.append)
        d.addBoth(lambda ignored: server.transport.loseConnection())

        runProtocolsWithReactor(self, server, client, self.endpoints)

        self.assertIsInstance(result[0], Failure)
        result[0].trap(ConnectionClosed)
        self.assertIsInstance(server.reason.value, FileDescriptorOverrun)

    if sendmsgSkip is not None:
        test_fileDescriptorOverrun.skip = sendmsgSkip

    def _sendmsgMixinFileDescriptorReceivedDriver(self, ancillaryPacker):
        """
        Drive _SendmsgMixin via sendmsg socket calls to check that
        L{IFileDescriptorReceiver.fileDescriptorReceived} is called once
        for each file descriptor received in the ancillary messages.

        @param ancillaryPacker: A callable that will be given a list of
            two file descriptors and should return a two-tuple where:
            The first item is an iterable of zero or more (cmsg_level,
            cmsg_type, cmsg_data) tuples in the same order as the given
            list for actual sending via sendmsg; the second item is an
            integer indicating the expected number of FDs to be received.
        """
        # Strategy:
        # - Create a UNIX socketpair.
        # - Associate one end to a FakeReceiver and FakeProtocol.
        # - Call sendmsg on the other end to send FDs as ancillary data.
        #   Ancillary data is obtained calling ancillaryPacker with
        #   the two FDs associated to two temp files (using the socket
        #   FDs for this fails the device/inode verification tests on
        #   macOS 10.10, so temp files are used instead).
        # - Call doRead in the FakeReceiver.
        # - Verify results on FakeProtocol.
        #   Using known device/inodes to verify correct order.

        # TODO: replace FakeReceiver test approach with one based in
        # IReactorSocket.adoptStreamConnection once AF_UNIX support is
        # implemented; see https://twistedmatrix.com/trac/ticket/5573.

        from socket import socketpair
        from twisted.internet.unix import _SendmsgMixin
        from twisted.python.sendmsg import sendmsg

        def deviceInodeTuple(fd):
            fs = fstat(fd)
            return (fs.st_dev, fs.st_ino)

        @implementer(IFileDescriptorReceiver)
        class FakeProtocol(ConnectableProtocol):
            def __init__(self):
                self.fds = []
                self.deviceInodesReceived = []

            def fileDescriptorReceived(self, fd):
                # Record the descriptor and its identity, then close it.
                self.fds.append(fd)
                self.deviceInodesReceived.append(deviceInodeTuple(fd))
                close(fd)

        class FakeReceiver(_SendmsgMixin):
            bufferSize = 1024

            def __init__(self, skt, proto):
                self.socket = skt
                self.protocol = proto

            def _dataReceived(self, data):
                pass

            def getHost(self):
                pass

            def getPeer(self):
                pass

            def _getLogPrefix(self, o):
                pass

        sendSocket, recvSocket = socketpair(AF_UNIX, SOCK_STREAM)
        self.addCleanup(sendSocket.close)
        self.addCleanup(recvSocket.close)

        proto = FakeProtocol()
        receiver = FakeReceiver(recvSocket, proto)

        # Temp files give us two FDs to send/receive/verify.
        fileOneFD, fileOneName = mkstemp()
        fileTwoFD, fileTwoName = mkstemp()
        self.addCleanup(unlink, fileOneName)
        self.addCleanup(unlink, fileTwoName)
        # mkstemp() hands back open descriptors; close them once the test is
        # over (they must stay open for the fstat() comparison below).
        self.addCleanup(close, fileOneFD)
        self.addCleanup(close, fileTwoFD)

        dataToSend = b'some data needs to be sent'
        fdsToSend = [fileOneFD, fileTwoFD]
        ancillary, expectedCount = ancillaryPacker(fdsToSend)
        sendmsg(sendSocket, dataToSend, ancillary)

        receiver.doRead()

        # Verify that fileDescriptorReceived was called once per expected FD.
        self.assertEqual(len(proto.fds), expectedCount)

        # Verify that received FDs are different from the sent ones.
        self.assertFalse(set(fdsToSend).intersection(set(proto.fds)))

        # Verify that FDs were received in the same order, if any.
        if proto.fds:
            deviceInodesSent = [deviceInodeTuple(fd) for fd in fdsToSend]
            self.assertEqual(deviceInodesSent, proto.deviceInodesReceived)

    def test_multiFileDescriptorReceivedPerRecvmsgOneCMSG(self):
        """
        When a single recvmsg yields one SCM_RIGHTS control message that
        carries two file descriptors, _SendmsgMixin calls
        L{IFileDescriptorReceiver.fileDescriptorReceived} once for each of
        them.
        """
        from twisted.python.sendmsg import SCM_RIGHTS

        def packBothInOneMessage(descriptors):
            # One control message whose payload is both descriptors packed
            # back to back as native ints; two callbacks are expected.
            payload = pack('ii', *descriptors)
            return [(SOL_SOCKET, SCM_RIGHTS, payload)], 2

        self._sendmsgMixinFileDescriptorReceivedDriver(packBothInOneMessage)

    # Use the shared sendmsg skip reason whenever one is set.
    if sendmsgSkip is not None:
        test_multiFileDescriptorReceivedPerRecvmsgOneCMSG.skip = sendmsgSkip

    def test_multiFileDescriptorReceivedPerRecvmsgTwoCMSGs(self):
        """
        When a single recvmsg yields two SCM_RIGHTS control messages, each
        carrying one file descriptor, _SendmsgMixin calls
        L{IFileDescriptorReceiver.fileDescriptorReceived} once for each of
        them.
        """
        from twisted.python.sendmsg import SCM_RIGHTS

        def packOneMessagePerDescriptor(descriptors):
            # A separate control message per descriptor, each holding one
            # packed int; two callbacks are expected.
            messages = []
            for descriptor in descriptors:
                packed = pack('i', descriptor)
                messages.append((SOL_SOCKET, SCM_RIGHTS, packed))
            return messages, 2

        self._sendmsgMixinFileDescriptorReceivedDriver(
            packOneMessagePerDescriptor)

    # The macOS skip takes precedence over the shared sendmsg skip reason.
    if platform.isMacOSX():
        test_multiFileDescriptorReceivedPerRecvmsgTwoCMSGs.skip = (
            "Multi control message ancillary sendmsg not supported on Mac.")
    elif sendmsgSkip is not None:
        test_multiFileDescriptorReceivedPerRecvmsgTwoCMSGs.skip = sendmsgSkip

    def test_multiFileDescriptorReceivedPerRecvmsgBadCMSG(self):
        """
        _SendmsgMixin handles multiple file descriptors per recvmsg, calling
        L{IFileDescriptorReceiver.fileDescriptorReceived} once per received
        file descriptor. Scenario: unsupported CMSGs.
        """
        # Given that we can't just send random/invalid ancillary data via the
        # packer for it to be sent via sendmsg -- the kernel would not accept
        # it -- we'll temporarily replace recvmsg with a fake one that produces
        # a non-supported ancillary message level/type. This being said, from
        # the perspective of the ancillaryPacker, all that is required is to
        # let the test driver know that 0 file descriptors are expected.
        from twisted.python import sendmsg

        def ancillaryPacker(fdsToSend):
            ancillary = []
            expectedCount = 0
            return ancillary, expectedCount

        def fakeRecvmsgUnsupportedAncillary(skt, *args, **kwargs):
            # Report one control message whose level/type (None, None) is
            # not SOL_SOCKET/SCM_RIGHTS.
            data = b'some data'
            ancillary = [(None, None, b'')]
            flags = 0
            return sendmsg.RecievedMessage(data, ancillary, flags)

        events = []
        addObserver(events.append)
        self.addCleanup(removeObserver, events.append)

        self.patch(sendmsg, "recvmsg", fakeRecvmsgUnsupportedAncillary)
        self._sendmsgMixinFileDescriptorReceivedDriver(ancillaryPacker)

        # Verify the expected message was logged.  Any code may log while the
        # observer is attached, and an event without a 'format' key (or with
        # a None one) must not abort the search with KeyError/TypeError; the
        # event matching elsewhere in this file also uses .get().
        expectedMessage = 'received unsupported ancillary data'
        found = any(expectedMessage in (e.get('format') or '')
                    for e in events)
        self.assertTrue(found, 'Expected message not found in logged events')

    if sendmsgSkip is not None:
        test_multiFileDescriptorReceivedPerRecvmsgBadCMSG.skip = sendmsgSkip

    def test_avoidLeakingFileDescriptors(self):
        """
        If associated with a protocol which does not provide
        L{IFileDescriptorReceiver}, file descriptors received by the
        L{IUNIXTransport} implementation are closed and a warning is emitted.
        """
        # To verify this, establish a connection.  Send one end of the
        # connection over the IUNIXTransport implementation.  After the copy
        # should no longer exist, close the original.  If the opposite end of
        # the connection decides the connection is closed, the copy does not
        # exist.
        from socket import socketpair
        probeClient, probeServer = socketpair()
        # Release both probe sockets even if the test fails part-way.
        # Closing an already-closed socket object is a no-op, so the explicit
        # probeClient.close() below still behaves as before.
        self.addCleanup(probeServer.close)
        self.addCleanup(probeClient.close)

        events = []
        addObserver(events.append)
        self.addCleanup(removeObserver, events.append)

        class RecordEndpointAddresses(SendFileDescriptor):
            def connectionMade(self):
                # Remember both ends so the expected log event can be built.
                self.hostAddress = self.transport.getHost()
                self.peerAddress = self.transport.getPeer()
                SendFileDescriptor.connectionMade(self)

        server = RecordEndpointAddresses(probeClient.fileno(), b"junk")
        client = ConnectableProtocol()

        runProtocolsWithReactor(self, server, client, self.endpoints)

        # Get rid of the original reference to the socket.
        probeClient.close()

        # A non-blocking recv will return "" if the connection is closed, as
        # desired.  If the connection has not been closed, because the
        # duplicate file descriptor is still open, it will fail with EAGAIN
        # instead.
        probeServer.setblocking(False)
        self.assertEqual(b"", probeServer.recv(1024))

        # This is a surprising circumstance, so it should be logged.
        format = ("%(protocolName)s (on %(hostAddress)r) does not "
                  "provide IFileDescriptorReceiver; closing file "
                  "descriptor received (from %(peerAddress)r).")
        clsName = "ConnectableProtocol"

        # Reverse host and peer, since the log event is from the client
        # perspective.
        expectedEvent = dict(hostAddress=server.peerAddress,
                             peerAddress=server.hostAddress,
                             protocolName=clsName,
                             format=format)

        for logEvent in events:
            for k, v in iteritems(expectedEvent):
                if v != logEvent.get(k):
                    break
            else:
                # No mismatches were found, stop looking at events
                break
        else:
            # No fully matching events were found, fail the test.
            self.fail("Expected event (%s) not found in logged events (%s)" %
                      (expectedEvent, pformat(events)))

    if sendmsgSkip is not None:
        test_avoidLeakingFileDescriptors.skip = sendmsgSkip

    def test_descriptorDeliveredBeforeBytes(self):
        """
        L{IUNIXTransport.sendFileDescriptor} sends file descriptors before
        L{ITransport.write} sends normal bytes.
        """
        @implementer(IFileDescriptorReceiver)
        class RecordEvents(ConnectableProtocol):
            def connectionMade(self):
                ConnectableProtocol.connectionMade(self)
                self.events = []

            def fileDescriptorReceived(innerSelf, descriptor):
                # ``self`` here is the enclosing test case, so the received
                # descriptor is closed when the test finishes.
                self.addCleanup(close, descriptor)
                innerSelf.events.append(type(descriptor))

            def dataReceived(self, data):
                self.events.extend(data)

        # The socket whose descriptor is sent; close it when the test ends.
        cargo = socket()
        self.addCleanup(cargo.close)
        server = SendFileDescriptor(cargo.fileno(), b"junk")
        client = RecordEvents()

        runProtocolsWithReactor(self, server, client, self.endpoints)

        # The descriptor (an int) must arrive before any of the data bytes.
        self.assertEqual(int, client.events[0])
        if _PY3:
            # Iterating bytes on Python 3 yields ints, so rebuild the bytes.
            self.assertEqual(b"junk", bytes(client.events[1:]))
        else:
            self.assertEqual(b"junk", b"".join(client.events[1:]))

    if sendmsgSkip is not None:
        test_descriptorDeliveredBeforeBytes.skip = sendmsgSkip
예제 #23
0
class FlockerClientTests(make_clientv1_tests()):
    """
    Interface tests for ``FlockerClient``.

    The shared interface tests come from ``make_clientv1_tests()``; this
    subclass runs them against a real ``FlockerClient`` talking HTTPS to an
    in-process control service REST API listening on 127.0.0.1.  The
    attributes ``self.node_1``, ``self.node_2`` and ``self.client`` are
    presumably provided by the generated base class -- verify there.
    """
    @skipUnless(platform.isLinux(),
                "flocker-node-era currently requires Linux.")
    @skipUnless(which("flocker-node-era"),
                "flocker-node-era needs to be in $PATH.")
    def create_client(self):
        """
        Create a new ``FlockerClient`` instance pointing at a running control
        service REST API.

        Starts a configuration persistence service, a cluster state service
        (seeded with this machine's era for ``node_1`` and basic node state
        for ``node_1`` and ``node_2``) and the REST API service on a free
        local port.  Cleanups are registered to stop each service.  Skipped
        unless running on Linux with ``flocker-node-era`` on ``$PATH``.

        :return: ``FlockerClient`` instance.
        """
        # Fake clock shared by the persistence service and the API service.
        clock = Clock()
        _, self.port = find_free_port()
        self.persistence_service = ConfigurationPersistenceService(
            clock, FilePath(self.mktemp()))
        self.persistence_service.startService()
        # Unlike the services above, cluster state uses the real reactor.
        self.cluster_state_service = ClusterStateService(reactor)
        self.cluster_state_service.startService()
        source = ChangeSource()
        # Prevent nodes being deleted by the state wiper.
        source.set_last_activity(reactor.seconds())
        # The era of the machine running the tests; registering it for node_1
        # is what lets ``this_node_uuid`` resolve to node_1 (see
        # test_this_node_uuid_retry).
        # NOTE(review): UUID() rejects trailing whitespace, so this relies on
        # flocker-node-era printing the UUID without a newline -- confirm.
        self.era = UUID(check_output(["flocker-node-era"]))
        self.cluster_state_service.apply_changes_from_source(
            source=source,
            changes=[UpdateNodeStateEra(era=self.era, uuid=self.node_1.uuid)] +
            [
                NodeState(uuid=node.uuid, hostname=node.public_address)
                for node in [self.node_1, self.node_2]
            ],
        )
        self.addCleanup(self.cluster_state_service.stopService)
        self.addCleanup(self.persistence_service.stopService)
        credential_set, _ = get_credential_sets()
        credentials_path = FilePath(self.mktemp())
        credentials_path.makedirs()

        api_service = create_api_service(
            self.persistence_service,
            self.cluster_state_service,
            TCP4ServerEndpoint(reactor, self.port, interface=b"127.0.0.1"),
            rest_api_context_factory(
                credential_set.root.credential.certificate,
                credential_set.control),
            # Use consistent fake time for API results:
            clock)
        api_service.startService()
        self.addCleanup(api_service.stopService)

        # Presumably writes cluster.crt, user.crt and user.key (the files
        # the client is pointed at below) into credentials_path.
        credential_set.copy_to(credentials_path, user=True)
        return FlockerClient(reactor, b"127.0.0.1", self.port,
                             credentials_path.child(b"cluster.crt"),
                             credentials_path.child(b"user.crt"),
                             credentials_path.child(b"user.key"))

    def synchronize_state(self):
        """
        Make the reported cluster state match the current configuration.

        For every configured node, a ``NodeState`` is applied with the same
        applications and manifestations.  Each manifestation's dataset is
        reported at ``/flocker/<dataset_id>``.
        """
        deployment = self.persistence_service.get()
        # No IP address known, so use UUID for hostname
        node_states = [
            NodeState(uuid=node.uuid,
                      hostname=unicode(node.uuid),
                      applications=node.applications,
                      manifestations=node.manifestations,
                      paths={
                          manifestation.dataset_id:
                          FilePath(b"/flocker").child(
                              bytes(manifestation.dataset_id))
                          for manifestation in node.manifestations.values()
                      },
                      devices={}) for node in deployment.nodes
        ]
        self.cluster_state_service.apply_changes(node_states)

    def get_configuration_tag(self):
        """
        Return the persistence service's ``configuration_hash()`` for the
        current configuration.
        """
        return self.persistence_service.configuration_hash()

    @capture_logging(None)
    def test_logging(self, logger):
        """
        Successful HTTP requests are logged.
        """
        dataset_id = uuid4()
        d = self.client.create_dataset(primary=self.node_1.uuid,
                                       maximum_size=None,
                                       dataset_id=dataset_id)
        # The request action must record URL, method and body, and succeed
        # with the created dataset as the response body.
        d.addCallback(lambda _: assertHasAction(
            self, logger, _LOG_HTTP_REQUEST, True,
            dict(url=b"https://127.0.0.1:{}/v1/configuration/datasets".format(
                self.port),
                 method=u"POST",
                 request_body=dict(primary=unicode(self.node_1.uuid),
                                   metadata={},
                                   dataset_id=unicode(dataset_id))),
            dict(response_body=dict(primary=unicode(self.node_1.uuid),
                                    metadata={},
                                    deleted=False,
                                    dataset_id=unicode(dataset_id)))))
        return d

    @capture_logging(None)
    def test_cross_process_logging(self, logger):
        """
        Eliot tasks can be traced from the HTTP client to the API server.
        """
        # Route the REST API's log messages into the same captured logger as
        # the client so both sides can be inspected together.
        self.patch(rest_api, "_logger", logger)
        my_action = ActionType("my_action", [], [])
        with my_action():
            d = self.client.create_dataset(primary=self.node_1.uuid)

        def got_response(_):
            # The server-side REQUEST action must be nested under the
            # client-side action that issued the call.
            parent = LoggedAction.ofType(logger.messages, my_action)[0]
            child = LoggedAction.ofType(logger.messages, REQUEST)[0]
            self.assertIn(child, list(parent.descendants()))

        d.addCallback(got_response)
        return d

    # The failed request must still be logged, as an unsuccessful action
    # whose end message records the ResponseError.
    @capture_logging(lambda self, logger: assertHasAction(
        self, logger, _LOG_HTTP_REQUEST, False,
        dict(url=b"https://127.0.0.1:{}/v1/configuration/datasets".format(
            self.port),
             method=u"POST",
             request_body=dict(primary=unicode(self.node_1.uuid),
                               maximum_size=u"notint",
                               metadata={})),
        {u'exception': u'flocker.apiclient._client.ResponseError'}))
    def test_unexpected_error(self, logger):
        """
        If the ``FlockerClient`` receives an unexpected HTTP response code it
        returns a ``ResponseError`` failure.
        """
        # A non-integer maximum_size is rejected by the API with BAD_REQUEST.
        d = self.client.create_dataset(primary=self.node_1.uuid,
                                       maximum_size=u"notint")
        self.assertFailure(d, ResponseError)
        d.addCallback(lambda exc: self.assertEqual(exc.code, BAD_REQUEST))
        return d

    def test_unset_primary(self):
        """
        If the ``FlockerClient`` receives a dataset state where primary is
        ``None`` it parses it correctly.
        """
        dataset_id = uuid4()
        # A dataset reported as non-manifest has no primary node.
        self.cluster_state_service.apply_changes([
            NonManifestDatasets(
                datasets={
                    unicode(dataset_id):
                    ModelDataset(dataset_id=unicode(dataset_id)),
                })
        ])
        d = self.client.list_datasets_state()
        d.addCallback(lambda states: self.assertEqual([
            DatasetState(dataset_id=dataset_id,
                         primary=None,
                         maximum_size=None,
                         path=None)
        ], states))
        return d

    def test_this_node_uuid_retry(self):
        """
        ``this_node_uuid`` retries if the node UUID is unknown.
        """
        # Pretend that the era for node 1 is something else; first try at
        # getting node UUID for real era will therefore fail:
        self.cluster_state_service.apply_changes(
            [UpdateNodeStateEra(era=uuid4(), uuid=self.node_1.uuid)])

        # When we lookup the DeploymentState the first time we'll set the
        # value to the correct one, so second try should succeed:
        def as_deployment(original=self.cluster_state_service.as_deployment):
            # Snapshot first (still with the wrong era), then fix the era so
            # that only a later lookup can see the correct value.
            result = original()
            self.cluster_state_service.apply_changes(changes=[
                UpdateNodeStateEra(era=self.era, uuid=self.node_1.uuid)
            ])
            return result

        self.patch(self.cluster_state_service, "as_deployment", as_deployment)

        d = self.client.this_node_uuid()
        d.addCallback(self.assertEqual, self.node_1.uuid)
        return d

    def test_this_node_uuid_no_retry_on_other_responses(self):
        """
        ``this_node_uuid`` doesn't retry on unexpected responses.
        """
        # Cause 500 errors to be raised by the API endpoint:
        self.patch(self.cluster_state_service, "as_deployment", lambda: 1 / 0)
        return self.assertFailure(self.client.this_node_uuid(), ResponseError)
예제 #24
0
파일: test_unix.py 프로젝트: enmand/twisted
class UNIXTestsBuilder(UNIXFamilyMixin, ReactorBuilder, ConnectionTestsMixin):
    """
    Builder defining tests relating to L{IReactorUNIX}.

    Tests are skipped for reactors that do not provide every interface in
    C{requiredInterfaces}.  Connected test protocols are driven through
    L{runProtocolsWithReactor} using the C{endpoints} creator below.
    """
    requiredInterfaces = (IReactorUNIX, )

    endpoints = UNIXCreator()

    def test_mode(self):
        """
        The UNIX socket created by L{IReactorUNIX.listenUNIX} is created with
        the mode specified.
        """
        self._modeTest('listenUNIX', self.mktemp(), ServerFactory())

    def test_listenOnLinuxAbstractNamespace(self):
        """
        On Linux, a UNIX socket path may begin with C{'\0'} to indicate a socket
        in the abstract namespace.  L{IReactorUNIX.listenUNIX} accepts such a
        path.
        """
        # Don't listen on a path longer than the maximum allowed.
        path = _abstractPath(self)
        reactor = self.buildReactor()
        port = reactor.listenUNIX('\0' + path, ServerFactory())
        self.assertEqual(port.getHost(), UNIXAddress('\0' + path))

    if not platform.isLinux():
        test_listenOnLinuxAbstractNamespace.skip = (
            'Abstract namespace UNIX sockets only supported on Linux.')

    def test_connectToLinuxAbstractNamespace(self):
        """
        L{IReactorUNIX.connectUNIX} also accepts a Linux abstract namespace
        path.
        """
        path = _abstractPath(self)
        reactor = self.buildReactor()
        connector = reactor.connectUNIX('\0' + path, ClientFactory())
        self.assertEqual(connector.getDestination(), UNIXAddress('\0' + path))

    if not platform.isLinux():
        test_connectToLinuxAbstractNamespace.skip = (
            'Abstract namespace UNIX sockets only supported on Linux.')

    def test_addresses(self):
        """
        A client's transport's C{getHost} and C{getPeer} return L{UNIXAddress}
        instances which have the filesystem path of the host and peer ends of
        the connection.
        """
        class SaveAddress(ConnectableProtocol):
            def makeConnection(self, transport):
                # Record both ends, then drop the connection immediately.
                self.addresses = dict(host=transport.getHost(),
                                      peer=transport.getPeer())
                transport.loseConnection()

        server = SaveAddress()
        client = SaveAddress()

        runProtocolsWithReactor(self, server, client, self.endpoints)

        # Each side's host address is the other side's peer address.
        self.assertEqual(server.addresses['host'], client.addresses['peer'])
        self.assertEqual(server.addresses['peer'], client.addresses['host'])

    def test_sendFileDescriptor(self):
        """
        L{IUNIXTransport.sendFileDescriptor} accepts an integer file descriptor
        and sends a copy of it to the process reading from the connection.
        """
        from socket import fromfd

        s = socket()
        # The original socket is no longer needed once the test is over.
        self.addCleanup(s.close)
        s.bind(('', 0))
        server = SendFileDescriptor(s.fileno(), "junk")

        client = ReceiveFileDescriptor()
        d = client.waitForDescriptor()

        def checkDescriptor(descriptor):
            received = fromfd(descriptor, AF_INET, SOCK_STREAM)
            # Thanks for the free dup, fromfd(): the raw descriptor can be
            # closed now, and the duplicate owned by ``received`` is closed
            # when the test finishes.
            close(descriptor)
            self.addCleanup(received.close)

            # If the sockets have the same local address, they're probably the
            # same.
            self.assertEqual(s.getsockname(), received.getsockname())

            # But it would be cheating for them to be identified by the same
            # file descriptor.  The point was to get a copy, as we might get if
            # there were two processes involved here.
            self.assertNotEqual(s.fileno(), received.fileno())

        d.addCallback(checkDescriptor)
        d.addErrback(err, "Sending file descriptor encountered a problem")
        d.addBoth(lambda ignored: server.transport.loseConnection())

        runProtocolsWithReactor(self, server, client, self.endpoints)

    if sendmsgSkip is not None:
        test_sendFileDescriptor.skip = sendmsgSkip

    def test_sendFileDescriptorTriggersPauseProducing(self):
        """
        If a L{IUNIXTransport.sendFileDescriptor} call fills up the send buffer,
        any registered producer is paused.
        """
        # The socket whose descriptor is sent over and over; it is closed
        # when the test finishes.
        cargo = socket()
        self.addCleanup(cargo.close)

        class DoesNotRead(ConnectableProtocol):
            def connectionMade(self):
                # Never read, so the sender's buffer eventually fills up.
                self.transport.pauseProducing()

        class SendsManyFileDescriptors(ConnectableProtocol):
            paused = False

            def connectionMade(self):
                self.socket = cargo
                # Register as a streaming producer so that the transport
                # pauses us once its send buffer is full.
                self.transport.registerProducer(self, True)

                def sender():
                    self.transport.sendFileDescriptor(self.socket.fileno())
                    self.transport.write("x")

                # Send a descriptor plus one byte as fast as the reactor
                # allows (interval 0).
                self.task = LoopingCall(sender)
                self.task.clock = self.transport.reactor
                self.task.start(0).addErrback(err, "Send loop failure")

            def stopProducing(self):
                self._disconnect()

            def resumeProducing(self):
                self._disconnect()

            def pauseProducing(self):
                # The event under test: record it and shut everything down.
                self.paused = True
                self.transport.unregisterProducer()
                self._disconnect()

            def _disconnect(self):
                self.task.stop()
                self.transport.abortConnection()
                self.other.transport.abortConnection()

        server = SendsManyFileDescriptors()
        client = DoesNotRead()
        server.other = client
        runProtocolsWithReactor(self, server, client, self.endpoints)

        self.assertTrue(server.paused,
                        "sendFileDescriptor producer was not paused")

    if sendmsgSkip is not None:
        test_sendFileDescriptorTriggersPauseProducing.skip = sendmsgSkip

    def test_fileDescriptorOverrun(self):
        """
        If L{IUNIXTransport.sendFileDescriptor} is used to queue a greater
        number of file descriptors than the number of bytes sent using
        L{ITransport.write}, the connection is closed and the protocol connected
        to the transport has its C{connectionLost} method called with a failure
        wrapping L{FileDescriptorOverrun}.
        """
        cargo = socket()
        # Release the sent socket once the test is over.
        self.addCleanup(cargo.close)
        # Presumably ``None`` means no bytes are written alongside the
        # descriptor, which is what produces the overrun -- confirm against
        # SendFileDescriptor.
        server = SendFileDescriptor(cargo.fileno(), None)

        client = ReceiveFileDescriptor()
        result = []
        d = client.waitForDescriptor()
        d.addBoth(result.append)
        d.addBoth(lambda ignored: server.transport.loseConnection())

        runProtocolsWithReactor(self, server, client, self.endpoints)

        self.assertIsInstance(result[0], Failure)
        result[0].trap(ConnectionClosed)
        self.assertIsInstance(server.reason.value, FileDescriptorOverrun)

    if sendmsgSkip is not None:
        test_fileDescriptorOverrun.skip = sendmsgSkip

    def test_avoidLeakingFileDescriptors(self):
        """
        If associated with a protocol which does not provide
        L{IFileDescriptorReceiver}, file descriptors received by the
        L{IUNIXTransport} implementation are closed and a warning is emitted.
        """
        # To verify this, establish a connection.  Send one end of the
        # connection over the IUNIXTransport implementation.  After the copy
        # should no longer exist, close the original.  If the opposite end of
        # the connection decides the connection is closed, the copy does not
        # exist.
        from socket import socketpair
        probeClient, probeServer = socketpair()
        # Release both probe sockets even if the test fails part-way; closing
        # an already-closed socket object is a no-op.
        self.addCleanup(probeServer.close)
        self.addCleanup(probeClient.close)

        events = []
        addObserver(events.append)
        self.addCleanup(removeObserver, events.append)

        class RecordEndpointAddresses(SendFileDescriptor):
            def connectionMade(self):
                # Remember both ends so the expected log event can be built.
                self.hostAddress = self.transport.getHost()
                self.peerAddress = self.transport.getPeer()
                SendFileDescriptor.connectionMade(self)

        server = RecordEndpointAddresses(probeClient.fileno(), "junk")
        client = ConnectableProtocol()

        runProtocolsWithReactor(self, server, client, self.endpoints)

        # Get rid of the original reference to the socket.
        probeClient.close()

        # A non-blocking recv will return "" if the connection is closed, as
        # desired.  If the connection has not been closed, because the duplicate
        # file descriptor is still open, it will fail with EAGAIN instead.
        probeServer.setblocking(False)
        self.assertEqual("", probeServer.recv(1024))

        # This is a surprising circumstance, so it should be logged.
        format = ("%(protocolName)s (on %(hostAddress)r) does not "
                  "provide IFileDescriptorReceiver; closing file "
                  "descriptor received (from %(peerAddress)r).")
        clsName = "ConnectableProtocol"

        # Reverse host and peer, since the log event is from the client
        # perspective.
        expectedEvent = dict(hostAddress=server.peerAddress,
                             peerAddress=server.hostAddress,
                             protocolName=clsName,
                             format=format)

        for logEvent in events:
            for k, v in expectedEvent.items():
                if v != logEvent.get(k):
                    break
            else:
                # No mismatches were found, stop looking at events
                break
        else:
            # No fully matching events were found, fail the test.
            self.fail("Expected event (%s) not found in logged events (%s)" %
                      (expectedEvent, pformat(events)))

    if sendmsgSkip is not None:
        test_avoidLeakingFileDescriptors.skip = sendmsgSkip

    def test_descriptorDeliveredBeforeBytes(self):
        """
        L{IUNIXTransport.sendFileDescriptor} sends file descriptors before
        L{ITransport.write} sends normal bytes.
        """
        class RecordEvents(ConnectableProtocol):
            # NOTE(review): implements() class advice only works on Python 2;
            # @implementer is the portable spelling if this file is ported.
            implements(IFileDescriptorReceiver)

            def connectionMade(self):
                ConnectableProtocol.connectionMade(self)
                self.events = []

            def fileDescriptorReceived(innerSelf, descriptor):
                # ``self`` is the enclosing test case: close the received
                # descriptor when the test finishes.
                self.addCleanup(close, descriptor)
                innerSelf.events.append(type(descriptor))

            def dataReceived(self, data):
                self.events.extend(data)

        # The socket whose descriptor is sent; close it when the test ends.
        cargo = socket()
        self.addCleanup(cargo.close)
        server = SendFileDescriptor(cargo.fileno(), "junk")
        client = RecordEvents()

        runProtocolsWithReactor(self, server, client, self.endpoints)

        # The descriptor (an int) must be recorded before any data byte.
        self.assertEqual([int, "j", "u", "n", "k"], client.events)

    if sendmsgSkip is not None:
        test_descriptorDeliveredBeforeBytes.skip = sendmsgSkip
예제 #25
0
class ReactorBuilder:
    """
    L{SynchronousTestCase} mixin which provides a reactor-creation API.  This
    mixin defines C{setUp} and C{tearDown}, so mix it in before
    L{SynchronousTestCase} or call its methods from the overridden ones in the
    subclass.

    @cvar skippedReactors: A dict mapping FQPN strings of reactors for
        which the tests defined by this class will be skipped to strings
        giving the skip message.
    @cvar requiredInterfaces: A C{list} of interfaces which the reactor must
        provide or these tests will be skipped.  The default, C{None}, means
        that no interfaces are required.
    @ivar reactorFactory: A no-argument callable which returns the reactor to
        use for testing.
    @ivar originalHandler: The SIGCHLD handler which was installed when setUp
        ran and which will be re-installed when tearDown runs.
    @ivar _reactors: A list of FQPN strings giving the reactors for which
        L{SynchronousTestCase}s will be created.
    """

    _reactors = [
        # Select works everywhere
        "twisted.internet.selectreactor.SelectReactor",
    ]

    if platform.isWindows():
        # PortableGtkReactor is only really interesting on Windows,
        # but not really Windows specific; if you want you can
        # temporarily move this up to the all-platforms list to test
        # it on other platforms.  It's not there in general because
        # it's not _really_ worth it to support on other platforms,
        # since no one really wants to use it on other platforms.
        _reactors.extend([
            "twisted.internet.gtk2reactor.PortableGtkReactor",
            "twisted.internet.gireactor.PortableGIReactor",
            "twisted.internet.gtk3reactor.PortableGtk3Reactor",
            "twisted.internet.win32eventreactor.Win32Reactor",
            "twisted.internet.iocpreactor.reactor.IOCPReactor"
        ])
    else:
        _reactors.extend([
            "twisted.internet.glib2reactor.Glib2Reactor",
            "twisted.internet.gtk2reactor.Gtk2Reactor",
            "twisted.internet.gireactor.GIReactor",
            "twisted.internet.gtk3reactor.Gtk3Reactor"
        ])
        if platform.isMacOSX():
            _reactors.append("twisted.internet.cfreactor.CFReactor")
        else:
            _reactors.extend([
                "twisted.internet.pollreactor.PollReactor",
                "twisted.internet.epollreactor.EPollReactor"
            ])
            if not platform.isLinux():
                # Presumably Linux is not going to start supporting kqueue, so
                # skip even trying this configuration.
                _reactors.extend([
                    # Support KQueue on non-OS-X POSIX platforms for now.
                    "twisted.internet.kqreactor.KQueueReactor",
                ])

    reactorFactory = None
    originalHandler = None
    requiredInterfaces = None
    skippedReactors = {}

    def setUp(self):
        """
        Clear the SIGCHLD handler, if there is one, to ensure an environment
        like the one which exists prior to a call to L{reactor.run}.
        """
        # SIGCHLD is only manipulated on non-Windows platforms.
        if not platform.isWindows():
            self.originalHandler = signal.signal(signal.SIGCHLD,
                                                 signal.SIG_DFL)

    def tearDown(self):
        """
        Restore the original SIGCHLD handler and reap processes as long as
        there seem to be any remaining.

        @raise Exception: If child processes are still registered for reaping
            after roughly 60 seconds; they are sent SIGKILL first.
        """
        if self.originalHandler is not None:
            signal.signal(signal.SIGCHLD, self.originalHandler)
        if process is not None:
            begin = time.time()
            while process.reapProcessHandlers:
                log.msg("ReactorBuilder.tearDown reaping some processes %r" %
                        (process.reapProcessHandlers, ))
                process.reapAllProcesses()

                # The process should exit on its own.  However, if it
                # doesn't, we're stuck in this loop forever.  To avoid
                # hanging the test suite, eventually give the process some
                # help exiting and move on.
                time.sleep(0.001)
                if time.time() - begin > 60:
                    for pid in process.reapProcessHandlers:
                        os.kill(pid, signal.SIGKILL)
                    raise Exception(
                        "Timeout waiting for child processes to exit: %r" %
                        (process.reapProcessHandlers, ))

    def unbuildReactor(self, reactor):
        """
        Clean up any resources which may have been allocated for the given
        reactor by its creation or by a test which used it.

        @param reactor: A reactor previously returned by L{buildReactor}.
        """
        # Chris says:
        #
        # XXX These explicit calls to clean up the waker (and any other
        # internal readers) should become obsolete when bug #3063 is
        # fixed. -radix, 2008-02-29. Fortunately it should probably cause an
        # error when bug #3063 is fixed, so it should be removed in the same
        # branch that fixes it.
        #
        # -exarkun
        reactor._uninstallHandler()
        if getattr(reactor, '_internalReaders', None) is not None:
            for reader in reactor._internalReaders:
                reactor.removeReader(reader)
                reader.connectionLost(None)
            reactor._internalReaders.clear()

        # Here's an extra thing unrelated to wakers but necessary for
        # cleaning up after the reactors we make.  -exarkun
        reactor.disconnectAll()

        # It would also be bad if any timed calls left over were allowed to
        # run.
        calls = reactor.getDelayedCalls()
        for c in calls:
            c.cancel()

    def buildReactor(self):
        """
        Create and return a reactor using C{self.reactorFactory}.

        The reactor is registered for cleanup via L{unbuildReactor}.

        @raise SkipTest: If the reactor cannot be created, does not provide
            every interface in C{requiredInterfaces}, or is a CFReactor being
            tested while a CFReactor is the global reactor.
        """
        try:
            from twisted.internet.cfreactor import CFReactor
            from twisted.internet import reactor as globalReactor
        except ImportError:
            pass
        else:
            if (isinstance(globalReactor, CFReactor)
                    and self.reactorFactory is CFReactor):
                raise SkipTest(
                    "CFReactor uses APIs which manipulate global state, "
                    "so it's not safe to run its own reactor-builder tests "
                    "under itself")
        try:
            reactor = self.reactorFactory()
        except Exception:
            # Unfortunately, not all errors which result in a reactor
            # being unusable are detectable without actually
            # instantiating the reactor.  So we catch some more here
            # and skip the test if necessary.  We also log it to aid
            # with debugging, but flush the logged error so the test
            # doesn't fail.  KeyboardInterrupt/SystemExit are deliberately
            # not caught so an interrupted run is not reported as a skip.
            log.err(None, "Failed to install reactor")
            self.flushLoggedErrors()
            raise SkipTest(Failure().getErrorMessage())
        else:
            if self.requiredInterfaces is not None:
                missing = [
                    required for required in self.requiredInterfaces
                    if not required.providedBy(reactor)
                ]
                if missing:
                    self.unbuildReactor(reactor)
                    raise SkipTest(
                        "%s does not provide %s" %
                        (fullyQualifiedName(reactor.__class__), ",".join(
                            [fullyQualifiedName(x) for x in missing])))
        self.addCleanup(self.unbuildReactor, reactor)
        return reactor

    def getTimeout(self):
        """
        Determine how long to run the test before considering it failed.

        @return: A C{int} or C{float} giving a number of seconds.
        """
        return acquireAttribute(self._parents, 'timeout',
                                DEFAULT_TIMEOUT_DURATION)

    def runReactor(self, reactor, timeout=None):
        """
        Run the reactor for at most the given amount of time.

        @param reactor: The reactor to run.

        @type timeout: C{int} or C{float}
        @param timeout: The maximum amount of time, specified in seconds, to
            allow the reactor to run.  If the reactor is still running after
            this much time has elapsed, it will be stopped and an exception
            raised.  If C{None}, the default test method timeout imposed by
            Trial will be used.  This depends on the L{IReactorTime}
            implementation of C{reactor} for correct operation.

        @raise TestTimeoutError: If the reactor is still running after
            C{timeout} seconds.
        """
        if timeout is None:
            timeout = self.getTimeout()

        # A list is used as a mutable flag the nested function can set.
        timedOut = []

        def stop():
            timedOut.append(None)
            reactor.stop()

        timedOutCall = reactor.callLater(timeout, stop)
        reactor.run()
        if timedOut:
            raise TestTimeoutError("reactor still running after %s seconds" %
                                   (timeout, ))
        else:
            timedOutCall.cancel()

    @classmethod
    def makeTestCaseClasses(cls):
        """
        Create a L{SynchronousTestCase} subclass which mixes in C{cls} for each
        known reactor and return a dict mapping their names to them.
        """
        classes = {}
        for reactor in cls._reactors:
            shortReactorName = reactor.split(".")[-1]
            name = (cls.__name__ + "." + shortReactorName).replace(".", "_")

            class testcase(cls, SynchronousTestCase):
                __module__ = cls.__module__
                if reactor in cls.skippedReactors:
                    skip = cls.skippedReactors[reactor]
                try:
                    reactorFactory = namedAny(reactor)
                except Exception:
                    # The reactor could not be loaded on this platform; skip
                    # the generated tests and report the loading error.
                    skip = Failure().getErrorMessage()

            testcase.__name__ = name
            classes[testcase.__name__] = testcase
        return classes
예제 #26
0
from twisted.trial.unittest import SynchronousTestCase, SkipTest

from ..blockdevice import (BlockDeviceDeployer, LoopbackBlockDeviceAPI,
                           IBlockDeviceAPI, BlockDeviceVolume, UnknownVolume,
                           AlreadyAttachedVolume, CreateBlockDeviceDataset,
                           UnattachedVolume, _losetup_list_parse,
                           _losetup_list, _blockdevicevolume_from_dataset_id)

from ... import InParallel, IStateChange
from ...testtools import ideployer_tests_factory
from ....control import Dataset, Manifestation, Node, NodeState, Deployment

# One gibibyte expressed in bytes (2 ** 30).
GIBIBYTE = 1024 * 1024 * 1024
# Four gibibytes; going by its name, a size meant to resemble a real-world
# block device.
REALISTIC_BLOCKDEVICE_SIZE = GIBIBYTE * 4

if not platform.isLinux():
    # The majority of Flocker isn't supported except on Linux - this test
    # module just happens to run some code that obviously breaks on some other
    # platforms.  Rather than skipping each test module individually it would
    # be nice to have some single global solution.  FLOC-1560, FLOC-1205
    #
    # NOTE(review): a module-level ``skip`` string is presumably picked up by
    # trial as the skip reason for every test in this module - confirm.
    skip = "flocker.node.agents.blockdevice is only supported on Linux"


class BlockDeviceDeployerTests(
        ideployer_tests_factory(lambda test: BlockDeviceDeployer(
            hostname=u"localhost",
            block_device_api=loopbackblockdeviceapi_for_test(test)))):
    """
    Tests for ``BlockDeviceDeployer``.
    """