Beispiel #1
0
class UnixSocketTest:
    """Test mixin that serves a Unix-domain socket through a ``Poller``.

    The methods call ``self.addCleanup`` and ``self.assertEqual``, so the
    class is meant to be combined with ``unittest.TestCase`` (or something
    offering the same methods) by the concrete test cases.
    """

    def prepare(self, *, buffering):
        """Create a listening Unix socket and register it with a new poller.

        Args:
            buffering: forwarded unchanged to ``Poller(buffering=...)``.
        """
        self._dir = mkdtemp()
        self.addCleanup(shutil.rmtree, self._dir)
        self._sock_file = os.path.join(self._dir, "sock")

        self._server = socket.socket(socket.AF_UNIX)
        # Registered immediately so the descriptor is released even when
        # bind()/listen() or the Poller setup below raises.  Closing a
        # socket twice is harmless, so this coexists with close_all().
        self.addCleanup(self._server.close)
        self._server.bind(self._sock_file)
        self._server.listen(0)

        self._poller = Poller(buffering=buffering)
        self._poller.add_server(self._server)
        self.addCleanup(self._poller.close_all)

    def verify(self, data, *expected):
        """Send *data* from a fresh client and check what the poller reports.

        Args:
            data: bytes sent by the client in a single ``send`` call.
            *expected: the payloads the second poll is expected to yield,
                in order, each paired with the client's channel.
        """
        client = socket.socket(socket.AF_UNIX)
        # Register the cleanup before connecting so a failed connect()
        # does not leak the client socket.
        self.addCleanup(client.close)
        client.connect(self._sock_file)

        with timeout(0.1):
            result = self._poller.poll(0.01)
        # The first event should be the accept notification; the second
        # element of its payload is used as the client's channel.
        cl_chan = result[0][0][1]

        with timeout(0.1):
            client.send(data)

        with timeout(0.1):
            result = self._poller.poll(0.01)

        self.assertEqual(result, [(a, cl_chan) for a in expected])
class TranslatorServer:
    """Dictionary-backed translation server speaking line-delimited JSON.

    It listens on a Unix socket and on a TCP socket.  Every received line
    is recorded in ``messages`` and answered with a JSON object holding
    either the translated ``intent``/``text`` or an ``error``.
    """

    def __init__(self, path, tcp_addr, translations, messages):
        """Store the configuration; no socket is opened until ``run_forever``.

        Args:
            path: filesystem path the Unix socket is bound to.
            tcp_addr: address the TCP socket is bound to.
            translations: mapping with the keys ``'human2pa'`` and
                ``'pa2human'``, each a lookup table used for replies.
            messages: list-like object; every received line is appended.
        """
        self.running = True
        self.server = None
        # Initialised here (not only in run_forever) so that stop() is
        # safe to call on a server that was never started.
        self.tcp = None
        self.client = None
        self.path = path
        self.addr = tcp_addr
        self._human2pa = translations['human2pa']
        self._pa2human = translations['pa2human']
        self._messages = messages
        self.poller = Poller(buffering='line')

    def run_forever(self):
        """Open both listening sockets and serve until ``running`` is false."""
        self.server = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        self.server.bind(self.path)
        self.server.listen()
        self.poller.add_server(self.server)
        self.tcp = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.tcp.bind(self.addr)
        self.tcp.listen()
        self.poller.add_server(self.tcp)
        # Publish the address actually bound (relevant when port 0 was
        # requested and the OS picked one).
        self.addr = self.tcp.getsockname()
        while self.running:
            for data, channel in self.poller.poll():
                if channel in (self.server, self.tcp):
                    _LOGGER.info("Client connected")
                elif data:
                    _LOGGER.debug("Data: %r", data)
                    self._messages.append(data)
                    event = json.loads(data)
                    if 'text' in event:
                        intent = self._human2pa.get(event['text'],
                                                    "unintelligible")
                        _LOGGER.info("Translator: %s->%s", event['text'],
                                     intent)
                        result = {'intent': intent}
                    elif 'intent' in event:
                        text = self._pa2human.get(event['intent'], "errored")
                        _LOGGER.info("Translator: %s->%s", event['intent'],
                                     text)
                        result = {'text': text}
                    else:
                        result = {
                            "error": "Either 'intent' or 'text' required"
                        }
                    channel.write(json.dumps(result).encode() + b'\n')

    def stop(self):
        """Ask the serving loop to finish and close whatever is open."""
        self.running = False
        if self.server is not None:
            self.server.close()
        if self.tcp is not None:
            self.tcp.close()
        if self.client is not None:
            self.client.close()

    def drop_client(self):
        """Close ``self.client`` if one is set.

        NOTE(review): nothing in this class assigns ``self.client``, so
        unless a caller sets it this is a no-op -- confirm intent.
        """
        if self.client is not None:
            self.client.close()
class TcpEndpoint:
    """Bridge between plain-text TCP clients and the router channel.

    The first line a client sends is taken as its user name and announced
    as a ``presence`` event; later lines are forwarded as messages.
    Messages coming from the router are delivered to the client named in
    their ``to.channel`` field.
    """

    def __init__(self, serv, router_channel):
        """Register the listening socket and the router with a new poller.

        Args:
            serv: listening server socket that accepts the TCP clients.
            router_channel: channel used to talk to the router.
        """
        self._router = router_channel
        self._serv = serv
        self._clients = {}        # client name -> client channel
        self._client_names = {}   # client channel -> client name
        self._usernames = {}      # client channel -> user name
        self._poller = Poller(buffering='line')
        self._poller.add_server(serv)
        self._poller.register(self._router)

    def send_name(self):
        """Introduce this endpoint to the router as ``tcp``."""
        self._router.write(b'tcp\n')

    def _notify_brain(self, payload, username, channel):
        """Send *payload* to the router, addressed from a client to the brain."""
        envelope = dict(payload)
        envelope["from"] = {"user": username,
                            "channel": self._client_names[channel]}
        envelope["to"] = {"user": "******", "channel": "brain"}
        # One buffer per write, as the other router/brain writers do.
        self._router.write(json.dumps(envelope).encode() + b'\n')

    def _forget_client(self, channel, username):
        """Close *channel* and drop every record kept about it."""
        if username is not None:
            # Only users that were announced need a matching "gone".
            self._notify_brain({"event": "gone"}, username, channel)
        channel.close()
        client_name = self._client_names.pop(channel, None)
        self._clients.pop(client_name, None)
        self._usernames.pop(channel, None)

    def _handle_user_data(self, data, channel):
        """Process one chunk read from a client channel.

        Args:
            data: bytes read from the client; empty bytes mean the client
                disconnected.
            channel: the client channel the data came from.
        """
        username = self._usernames.get(channel)
        if not data:
            # Handle the disconnect before the login check: a client that
            # leaves without naming itself must not be announced as a user
            # with an empty name and then kept forever.
            self._forget_client(channel, username)
            return

        line = data.decode().strip()
        if username is None:
            self._notify_brain({"event": "presence"}, line, channel)
            self._usernames[channel] = line
        elif line:
            self._notify_brain({"message": line}, username, channel)

    def poll(self, timeout=None):
        """Handle every event currently reported by the poller.

        Args:
            timeout: forwarded to the poller's ``poll``.

        Raises:
            RouterDisconnectedException: the router channel delivered
                empty data.
        """
        for data, channel in self._poller.poll(timeout):
            _LOGGER.debug("Got %s from %s", data, channel)
            if channel == self._serv:
                addr, client = data
                client.write(b'Please enter your name> ')
                client_name = 'tcp:'+addr[0]+':'+str(addr[1])
                self._clients[client_name] = client
                self._client_names[client] = client_name
            elif channel == self._router:
                if not data:
                    raise RouterDisconnectedException()
                msg = json.loads(data.decode())
                client = self._clients.get(msg['to']['channel'])
                if client is not None:
                    client.write(b'Niege> '+msg['message'].encode()+b'\n')
            else:
                self._handle_user_data(data, channel)

    def shutdown(self):
        """Close every client channel and forget all bookkeeping."""
        for client in self._client_names:
            client.close()

        self._clients.clear()
        self._client_names.clear()
        self._usernames.clear()
Beispiel #4
0
class Router:
    """Relay JSON lines between endpoint channels and the brain channel.

    An endpoint introduces itself with its name as the first line it
    sends.  Its later messages addressed to the ``brain`` channel are
    forwarded to the brain, preceded by a ``switch-user`` command whenever
    the sending user changes.  Messages from the brain are delivered to
    the endpoint named by the part of ``to.channel`` before the first
    colon.
    """

    def __init__(self, serv, be):
        """Register the server socket and the brain with a new poller.

        Args:
            serv: listening server socket on which endpoints connect.
            be: channel connected to the brain (the back end).
        """
        self._serv = serv
        self._be = be
        self._poller = Poller(buffering='line')
        self._poller.add_server(self._serv)
        self._poller.register(self._be)
        self._current_user = None
        self._endpoints = {}        # endpoint name -> channel
        self._endpoint_names = {}   # channel -> endpoint name
        self._endpoint_users = {}   # channel -> set of user names seen

    def _get_endpoint(self, endpoint_name):
        """Return the channel registered for the prefix of *endpoint_name*.

        Only the part before the first ``:`` is used for the lookup;
        ``None`` is returned when no such endpoint is registered.
        """
        return self._endpoints.get(endpoint_name.split(':', 1)[0])

    def _add_endpoint(self, endpoint, name):
        """Record *endpoint* under *name* with an empty set of users."""
        _LOGGER.debug("Registered endpoint %s", name)
        self._endpoint_names[endpoint] = name
        self._endpoints[name] = endpoint
        self._endpoint_users[endpoint] = set()

    def _remove_endpoint(self, endpoint):
        """Forget *endpoint*, report its users as gone, and close it.

        A channel that never registered a name is simply closed.
        """
        name = self._endpoint_names.get(endpoint)
        if name is not None:
            _LOGGER.debug("Removing endpoint %s", name)
            del self._endpoint_names[endpoint]
            del self._endpoints[name]
            gone_msg = {
                "event": "gone",
                "from": {
                    "channel": name
                },
                "to": {
                    "user": "******",
                    "channel": "brain"
                }
            }
            for user in self._endpoint_users[endpoint]:
                self._switch_user(user)
                # The same dict is reused; only the user differs per line.
                gone_msg["from"]["user"] = user
                self._be.write(json.dumps(gone_msg).encode() + b'\n')
            del self._endpoint_users[endpoint]
        endpoint.close()

    def _switch_user(self, user):
        """Tell the brain that *user* is now active, if that is a change."""
        if user != self._current_user:
            switch_msg = {"command": "switch-user", "user": user}
            # One buffer per write, consistent with _remove_endpoint.
            self._be.write(json.dumps(switch_msg).encode() + b'\n')
            self._current_user = user

    def _route_user_message(self, channel, data):
        """Handle one non-empty line received from an endpoint channel.

        The first line from an unknown channel registers it under that
        name.  Later lines are parsed as JSON and forwarded to the brain
        unchanged when their ``to.channel`` is ``brain``; anything else is
        dropped.
        """
        if channel not in self._endpoint_names:
            name = data.decode().strip()
            self._add_endpoint(channel, name)
            return
        msg = json.loads(data.decode())
        if msg['to']['channel'] != 'brain':
            return
        if 'from' in msg:
            user = msg["from"]["user"]
            self._switch_user(user)
            self._endpoint_users[channel].add(user)
        self._be.write(data)

    def tick(self, timeout=None):
        """Process every event currently reported by the poller.

        Args:
            timeout: forwarded to the poller's ``poll``.

        Raises:
            BrainDisconnectedException: the brain channel delivered empty
                data.
        """
        for data, channel in self._poller.poll(timeout):
            _LOGGER.debug("Got %s on %s", data, channel)
            if channel == self._serv:
                continue
            elif channel == self._be:
                if not data:
                    raise BrainDisconnectedException()
                msg = json.loads(data.decode())
                endpoint = self._get_endpoint(msg['to']['channel'])
                if endpoint is not None:
                    try:
                        endpoint.write(data)
                    except EndpointClosedException:
                        # Drop the endpoint that failed.  ``channel`` is
                        # the brain in this branch and must stay open.
                        self._remove_endpoint(endpoint)
            else:
                if data:
                    self._route_user_message(channel, data)
                else:
                    self._remove_endpoint(channel)
Beispiel #5
0
class PollerTest(unittest.TestCase):
    def setUp(self):
        self._poller = Poller()

    def tearDown(self):
        self._poller.close_all()

    def test_blocking_poll(self):
        with self.assertRaisesRegex(Exception, "timeout"), timeout(0.02):
            self._poller.poll()

    def test_timed_poll(self):
        with timeout(0.02):
            result = self._poller.poll(0.01)
        self.assertEqual(list(result), [])

    def test_timed_out_poll(self):
        with self.assertRaisesRegex(Exception, "timeout"), timeout(0.02):
            self._poller.poll(0.03)

    def test_poll_data(self):
        chan = TestChannel()
        self._poller.register(chan)
        chan.put(b'hello\n')

        with timeout(0.02):
            result = self._poller.poll()

        self.assertEqual(list(result), [(b'hello\n', chan)])

    def test_poll_no_data(self):
        chan = TestChannel()
        self._poller.register(chan)

        with timeout(0.02):
            result = self._poller.poll(0.01)

        self.assertEqual(list(result), [])

    def test_poll_accept(self):
        client, cl_chan = _connect_and_get_client_channel(self, self._poller)

        client.send(b'hello\n')

        result = self._poller.poll(0.01)
        self.assertEqual(list(result), [(b'hello\n', cl_chan)])

    def test_close_all_channels(self):
        chan = TestChannel()
        self._poller.register(chan)

        self._poller.close_all()

        with self.assertRaises(EndpointClosedException):
            chan.read()

    def test_close_all_servers(self):
        serv = socket.socket()
        serv.bind(('127.0.0.1', 0))
        serv.listen(0)
        self.addCleanup(serv.close)
        self._poller.add_server(serv)

        self._poller.close_all()

        with timeout(0.01), self.assertRaises(OSError):
            serv.accept()

    def test_unregister(self):
        chan = TestChannel()
        self._poller.register(chan)
        chan.put(b'hello\n')

        self._poller.unregister(chan)

        with self.assertRaisesRegex(Exception, "timeout"), timeout(0.02):
            self._poller.poll()

    def test_unregister_twice(self):
        chan = TestChannel()
        self._poller.register(chan)
        chan.put(b'hello\n')

        self._poller.unregister(chan)
        self._poller.unregister(chan)

    def test_closed(self):
        chan = TestChannel()
        self._poller.register(chan)

        chan.close()

        with timeout(0.01):
            result = self._poller.poll()

        self.assertEquals(result, [(b'', chan)])

    def test_disconnect(self):
        client, cl_chan = _connect_and_get_client_channel(self, self._poller)

        client.close()
        result = list(self._poller.poll(0.01))

        self.assertEqual(result, [(b'', cl_chan)])

    def test_unregister_on_disconnect(self):
        client, cl_chan = _connect_and_get_client_channel(self, self._poller)
        client.close()
        self._poller.poll(0.01)

        result = list(self._poller.poll(0.01))

        self.assertEqual(result, [])