class IncomingEndpoint:
    """Forward bytes written into a named pipe (FIFO) to the router.

    Creating the endpoint announces it to the router as ``incoming``,
    creates the FIFO at *pipe_filename* (``os.mkfifo`` fails if the path
    already exists) and starts polling it for data.
    """

    def __init__(self, pipe_filename, router):
        self._pipe = pipe_filename
        self._router = router
        # Name this endpoint to the router before the FIFO is created.
        router.write(b'incoming\n')
        os.mkfifo(self._pipe)
        self._poller = Poller()
        self._open_pipe()

    def _open_pipe(self):
        """Open the FIFO read-only without blocking and register it for polling."""
        # O_NONBLOCK lets the open return even when no writer is attached yet.
        fd = os.open(self._pipe, os.O_RDONLY | os.O_NONBLOCK)
        self._poller.register(PipeChannel(fd))

    def shutdown(self):
        """Delete the FIFO if it is still on disk, then close every polled channel."""
        pipe_exists = os.path.exists(self._pipe)
        if pipe_exists:
            os.unlink(self._pipe)
        self._poller.close_all()

    def poll(self, timeout=None):
        """Wait up to *timeout* for pipe data and relay each chunk to the router.

        An empty chunk means the pipe reached end-of-file (presumably its
        last writer closed it), so the FIFO is reopened for the next writer.
        """
        for chunk, _channel in self._poller.poll(timeout):
            if not chunk:
                # NOTE(review): the previous PipeChannel is not closed here;
                # confirm the poller closes channels when it reports EOF.
                self._open_pipe()
                continue
            self._router.write(chunk)
示例#2
0
class Router:
    """Route line-oriented JSON messages between endpoints and the brain.

    ``serv`` is a listening server registered with the poller; channels it
    accepts become endpoints. ``be`` is the channel connected to the brain.
    The first line an endpoint sends is taken as its name; every later line
    is a JSON message, forwarded to the brain when addressed to the
    ``brain`` channel. Lines coming from the brain are forwarded to the
    endpoint whose registered name matches the part of ``to.channel`` before
    the first ``:``.
    """

    def __init__(self, serv, be):
        self._serv = serv
        self._be = be
        self._poller = Poller(buffering='line')
        self._poller.add_server(self._serv)
        self._poller.register(self._be)
        # User the brain was last switched to; a "switch-user" command is
        # only sent when this changes.
        self._current_user = None
        self._endpoints = {}        # endpoint name -> channel
        self._endpoint_names = {}   # channel -> endpoint name
        self._endpoint_users = {}   # channel -> set of users seen on it

    def _send_to_brain(self, msg):
        """Serialize *msg* as a single JSON line and write it to the brain."""
        self._be.write(json.dumps(msg).encode() + b'\n')

    def _get_endpoint(self, endpoint_name):
        """Return the endpoint channel for *endpoint_name*, or None.

        Only the text before the first ``:`` is looked up, so
        ``"tcp:1.2.3.4:5"`` resolves to the endpoint registered as ``"tcp"``.
        """
        return self._endpoints.get(endpoint_name.split(':', 1)[0])

    def _add_endpoint(self, endpoint, name):
        """Register channel *endpoint* under *name* with no known users yet."""
        _LOGGER.debug("Registered endpoint %s", name)
        self._endpoint_names[endpoint] = name
        self._endpoints[name] = endpoint
        self._endpoint_users[endpoint] = set()

    def _remove_endpoint(self, endpoint):
        """Forget *endpoint*, report its users as gone to the brain, and close it.

        For every user seen on a named endpoint, the brain is switched to
        that user and sent a ``gone`` event. Channels that never sent a name
        are just unregistered and closed.
        """
        name = self._endpoint_names.pop(endpoint, None)
        if name is not None:
            _LOGGER.debug("Removing endpoint %s", name)
            # Only drop the name mapping if it still points at this channel;
            # a newer endpoint may have registered under the same name.
            if self._endpoints.get(name) is endpoint:
                del self._endpoints[name]
            for user in self._endpoint_users.pop(endpoint, ()):
                self._switch_user(user)
                self._send_to_brain({
                    "event": "gone",
                    "from": {
                        "channel": name,
                        "user": user,
                    },
                    "to": {
                        "user": "******",
                        "channel": "brain",
                    },
                })
        # Stop polling the channel; unregistering a channel the poller has
        # already dropped is harmless.
        self._poller.unregister(endpoint)
        endpoint.close()

    def _switch_user(self, user):
        """Tell the brain to switch to *user* unless it is already current."""
        if user != self._current_user:
            self._send_to_brain({"command": "switch-user", "user": user})
            self._current_user = user

    def _route_user_message(self, channel, data):
        """Handle one line *data* received from endpoint *channel*.

        The first line from an unknown channel registers it under that name.
        Later lines are parsed as JSON. Only messages addressed to the
        ``brain`` channel are forwarded, as the original bytes. If the
        message has a ``from`` field, the brain is first switched to the
        sending user, and that user is recorded for this endpoint.
        """
        if channel not in self._endpoint_names:
            self._add_endpoint(channel, data.decode().strip())
            return
        msg = json.loads(data.decode())
        if msg['to']['channel'] != 'brain':
            return
        if 'from' in msg:
            user = msg["from"]["user"]
            self._switch_user(user)
            self._endpoint_users[channel].add(user)
        self._be.write(data)

    def _route_brain_message(self, data):
        """Forward one JSON line from the brain to the endpoint it addresses."""
        msg = json.loads(data.decode())
        endpoint = self._get_endpoint(msg['to']['channel'])
        if endpoint is None:
            return
        try:
            endpoint.write(data)
        except EndpointClosedException:
            # The destination endpoint is dead: drop it, not the brain
            # channel that delivered the message.
            self._remove_endpoint(endpoint)

    def tick(self, timeout=None):
        """Poll once (waiting up to *timeout*) and route everything received.

        Raises:
            BrainDisconnectedException: the brain channel reached EOF.
        """
        for data, channel in self._poller.poll(timeout):
            _LOGGER.debug("Got %s on %s", data, channel)
            if channel == self._serv:
                # Accept notification; the new channel names itself with
                # its first line.
                continue
            if channel == self._be:
                if not data:
                    raise BrainDisconnectedException()
                self._route_brain_message(data)
            elif data:
                self._route_user_message(channel, data)
            else:
                self._remove_endpoint(channel)
class TcpEndpoint:
    """Bridge plain-text TCP clients to the router.

    Each accepted client is prompted for a name. Its first line becomes its
    user name and is announced to the brain with a ``presence`` event. Later
    non-blank lines are forwarded as ``message`` events, and a disconnect is
    reported as a ``gone`` event. Lines from the router are written to the
    client named in ``to.channel``.
    """

    def __init__(self, serv, router_channel):
        self._router = router_channel
        self._serv = serv
        self._clients = {}        # client name ("tcp:host:port") -> channel
        self._client_names = {}   # channel -> client name
        self._usernames = {}      # channel -> user name, once entered
        self._poller = Poller(buffering='line')
        self._poller.add_server(serv)
        self._poller.register(self._router)

    def send_name(self):
        """Announce this endpoint to the router under the name ``tcp``."""
        self._router.write(b'tcp\n')

    def _send_to_router(self, msg):
        """Serialize *msg* as a single JSON line and write it to the router."""
        self._router.write(json.dumps(msg).encode() + b'\n')

    def _drop_client(self, channel):
        """Close *channel* and forget every piece of bookkeeping about it."""
        channel.close()
        client_name = self._client_names.pop(channel, None)
        if client_name is not None:
            self._clients.pop(client_name, None)
        self._usernames.pop(channel, None)

    def _handle_user_data(self, data, channel):
        """Process one line (or EOF, when *data* is empty) from a client."""
        client_name = self._client_names[channel]
        username = self._usernames.get(channel)
        if not data:
            # EOF: the client disconnected. Only report it as gone if it had
            # announced its presence, i.e. already sent a name.
            if username is not None:
                self._send_to_router({"event": "gone",
                                      "from": {"user": username,
                                               "channel": client_name},
                                      "to": {"user": "******",
                                             "channel": "brain"}})
            self._drop_client(channel)
            return
        line = data.decode().strip()
        if username is None:
            # The first line a client sends is its user name.
            self._send_to_router({"event": "presence",
                                  "from": {"user": line,
                                           "channel": client_name},
                                  "to": {"user": "******",
                                         "channel": "brain"}})
            self._usernames[channel] = line
        elif line:
            self._send_to_router({"message": line,
                                  "from": {"user": username,
                                           "channel": client_name},
                                  "to": {"user": "******",
                                         "channel": "brain"}})

    def poll(self, timeout=None):
        """Poll once (waiting up to *timeout*) and handle everything received.

        Raises:
            RouterDisconnectedException: the router channel reached EOF.
        """
        for data, channel in self._poller.poll(timeout):
            _LOGGER.debug("Got %s from %s", data, channel)
            if channel == self._serv:
                # Accept notification: data is (address, client channel).
                addr, client = data
                client.write(b'Please enter your name> ')
                client_name = 'tcp:'+addr[0]+':'+str(addr[1])
                self._clients[client_name] = client
                self._client_names[client] = client_name
            elif channel == self._router:
                if not data:
                    raise RouterDisconnectedException()
                msg = json.loads(data.decode())
                client = self._clients.get(msg['to']['channel'])
                if client is not None:
                    client.write(b'Niege> '+msg['message'].encode()+b'\n')
            else:
                self._handle_user_data(data, channel)

    def shutdown(self):
        """Close every client connection and clear all bookkeeping."""
        for client in self._client_names:
            client.close()

        self._clients.clear()
        self._client_names.clear()
        self._usernames.clear()
示例#4
0
class PollerTest(unittest.TestCase):
    def setUp(self):
        """Create a fresh, empty Poller for every test."""
        poller = Poller()
        self._poller = poller

    def tearDown(self):
        """Close whatever channels and servers the test left in the poller."""
        poller = self._poller
        poller.close_all()

    def test_blocking_poll(self):
        """poll() with no timeout blocks until the 0.02 s guard fires."""
        with self.assertRaisesRegex(Exception, "timeout"):
            with timeout(0.02):
                self._poller.poll()

    def test_timed_poll(self):
        """poll(0.01) returns empty-handed before the 0.02 s guard fires."""
        with timeout(0.02):
            events = self._poller.poll(0.01)
        self.assertEqual(list(events), [])

    def test_timed_out_poll(self):
        """poll(0.03) outlasts the 0.02 s guard, which therefore fires."""
        with self.assertRaisesRegex(Exception, "timeout"):
            with timeout(0.02):
                self._poller.poll(0.03)

    def test_poll_data(self):
        """Data queued on a registered channel is reported by poll()."""
        channel = TestChannel()
        self._poller.register(channel)
        channel.put(b'hello\n')

        with timeout(0.02):
            events = self._poller.poll()

        expected = [(b'hello\n', channel)]
        self.assertEqual(list(events), expected)

    def test_poll_no_data(self):
        """A registered but idle channel yields nothing within the poll timeout."""
        channel = TestChannel()
        self._poller.register(channel)

        with timeout(0.02):
            events = self._poller.poll(0.01)

        self.assertEqual(list(events), [])

    def test_poll_accept(self):
        """Bytes sent by a connected client arrive on its accepted channel."""
        sock, accepted = _connect_and_get_client_channel(self, self._poller)

        sock.send(b'hello\n')

        events = self._poller.poll(0.01)
        expected = [(b'hello\n', accepted)]
        self.assertEqual(list(events), expected)

    def test_close_all_channels(self):
        """close_all() closes registered channels, so reading one raises."""
        channel = TestChannel()
        poller = self._poller
        poller.register(channel)

        poller.close_all()

        with self.assertRaises(EndpointClosedException):
            channel.read()

    def test_close_all_servers(self):
        """close_all() also closes listening sockets added via add_server()."""
        listener = socket.socket()
        # Port 0 asks the OS for any free port.
        listener.bind(('127.0.0.1', 0))
        listener.listen(0)
        self.addCleanup(listener.close)
        self._poller.add_server(listener)

        self._poller.close_all()

        with timeout(0.01):
            with self.assertRaises(OSError):
                listener.accept()

    def test_unregister(self):
        """An unregistered channel is no longer polled, even with data pending."""
        channel = TestChannel()
        poller = self._poller
        poller.register(channel)
        channel.put(b'hello\n')

        poller.unregister(channel)

        with self.assertRaisesRegex(Exception, "timeout"):
            with timeout(0.02):
                poller.poll()

    def test_unregister_twice(self):
        """Unregistering an already-unregistered channel must not raise."""
        channel = TestChannel()
        self._poller.register(channel)
        channel.put(b'hello\n')

        for _attempt in range(2):
            self._poller.unregister(channel)

    def test_closed(self):
        """A channel closed locally is reported once as an empty read."""
        chan = TestChannel()
        self._poller.register(chan)

        chan.close()

        with timeout(0.01):
            result = self._poller.poll()

        # assertEquals was a deprecated alias removed in Python 3.12.
        # Materialize the result like the sibling tests, so the check does
        # not depend on poll() returning a list.
        self.assertEqual(list(result), [(b'', chan)])

    def test_disconnect(self):
        """A client hanging up shows up as an empty read on its channel."""
        sock, accepted = _connect_and_get_client_channel(self, self._poller)

        sock.close()
        events = list(self._poller.poll(0.01))

        expected = [(b'', accepted)]
        self.assertEqual(events, expected)

    def test_unregister_on_disconnect(self):
        """After reporting a disconnect once, the poller stops watching that channel."""
        sock, _accepted = _connect_and_get_client_channel(self, self._poller)
        sock.close()
        # The first poll reports the disconnect (see test_disconnect).
        self._poller.poll(0.01)

        events = list(self._poller.poll(0.01))

        self.assertEqual(events, [])