Example #1
0
    async def init_app(self):
        '''Set up the application-wide background services.

        When ``self.probeRedisOnStartup`` is set, first waits (bounded by
        ``self.redisStartupProbingTimeout``) for every redis connection to be
        ready. Then builds a ServerStats object on the STATS_APPKEY redis
        client and stores it in ``self.app['stats']``; its ``run()`` coroutine
        is scheduled only when ``self.enableStats`` is true.

        If ``self.app['memory_debugger']`` is truthy, that entry is replaced by
        a MemoryDebugger instance (configured from the
        ``memory_debugger_no_tracemalloc`` / ``memory_debugger_print_all_tasks``
        app entries) whose ``run()`` coroutine is scheduled as well. Every
        scheduled task is registered with ``addTaskCleanup``.
        '''
        if self.probeRedisOnStartup:
            # Do not start serving until all redis nodes answer.
            await self.waitForAllConnectionsToBeReady(
                timeout=self.redisStartupProbingTimeout)

        stats = ServerStats(
            self.redisClients.getRedisClient(STATS_APPKEY), STATS_APPKEY)
        self.app['stats'] = stats

        if self.enableStats:
            self.serverStatsTask = asyncio.ensure_future(stats.run())
            addTaskCleanup(self.serverStatsTask)

        if not self.app.get('memory_debugger'):
            return

        debugger = MemoryDebugger(
            noTraceMalloc=self.app.get('memory_debugger_no_tracemalloc'),
            printAllTasks=self.app.get('memory_debugger_print_all_tasks'),
        )
        # The truthy flag is swapped for the live debugger object.
        self.app['memory_debugger'] = debugger
        self.memoryDebuggerTask = asyncio.ensure_future(debugger.run())
        addTaskCleanup(self.memoryDebuggerTask)
Example #2
0
    async def on_init(self):
        '''Start this component's background tasks.

        A MemoryDebugger (with tracemalloc disabled) is started unless
        ``self.args['disable_debug_memory']`` is truthy; a task running
        ``self.printStats()`` is always started. Both tasks are registered
        with ``addTaskCleanup``.
        '''
        wantMemoryDebug = not self.args['disable_debug_memory']
        if wantMemoryDebug:
            debugger = MemoryDebugger(noTraceMalloc=True)
            self.memoryDebuggerTask = asyncio.ensure_future(debugger.run())
            addTaskCleanup(self.memoryDebuggerTask)

        statsCoroutine = self.printStats()
        self.statsTask = asyncio.ensure_future(statsCoroutine)
        addTaskCleanup(self.statsTask)
Example #3
0
async def runSubscriber(url, credentials, channel, position, password):
    '''Subscribe to *channel* and wait until the subscription task finishes.

    No stream SQL filter is applied; *password* is forwarded to
    MessageHandlerClass through its args dict under the 'password' key.
    '''
    handlerArgs = dict(password=password)
    subscription = subscribeClient(url, credentials, channel, position,
                                   None, MessageHandlerClass, handlerArgs)
    task = asyncio.ensure_future(subscription)
    addTaskCleanup(task)
    await task
Example #4
0
    async def connect(self):
        '''Open the websocket to ``self.url`` and authenticate with role/secret.

        Credentials are read from ``self.creds`` ('role' and 'secret') and
        validated before any network activity. A bad role therefore no longer
        leaves a half-open websocket and a dangling background task behind.

        After connecting, a task running ``self.waitForResponses()`` is
        scheduled (and registered with ``addTaskCleanup``), ``self.stop`` is
        set to a fresh future on the current loop, and a two-step handshake
        is performed:

        1. ``auth/handshake`` with the role. The reply's ``body.data`` supplies
           ``version``, ``connection_id`` (stored on self) and a ``nonce``.
        2. ``auth/authenticate`` with ``computeHash(secret, nonce)``.

        Raises:
            ValueError: if ``self.creds['role']`` is None.
        '''
        # Validate credentials up front, before opening the socket.
        role = self.creds['role']
        if role is None:
            raise ValueError('connect: Missing role')
        secret = bytearray(self.creds['secret'], 'utf8')

        self.websocket = await websockets.connect(self.url)
        # Background task for incoming frames; presumably resolves the replies
        # awaited by self.send() — confirm against waitForResponses.
        self.task = asyncio.ensure_future(self.waitForResponses())
        addTaskCleanup(self.task)

        # asyncio.get_running_loop() only exists on Python >= 3.7.
        if sys.version_info[:2] < (3, 7):
            self.stop = asyncio.get_event_loop().create_future()
        else:
            self.stop = asyncio.get_running_loop().create_future()

        handshake = {
            "action": "auth/handshake",
            "body": {
                "data": {
                    "role": role
                },
                "method": "role_secret"
            },
        }

        response = await self.send(handshake)

        data = response['body']['data']
        self.serverVersion = data['version']
        self.connectionId = data['connection_id']
        nonce = bytearray(data['nonce'], 'utf8')

        challenge = {
            "action": "auth/authenticate",
            "body": {
                "method": "role_secret",
                "credentials": {
                    "hash": computeHash(secret, nonce)
                },
            },
        }
        await self.send(challenge)
Example #5
0
async def runClient(
    url,
    role,
    secret,
    channel,
    position,
    stream_sql,
    verbose,
    username,
    password,
    loop,
    inputs,
    stop,
):
    '''Run the interactive chat client until *stop* completes.

    A subscription task (``subscribeClient`` with ``MessageHandlerClass``) is
    started with an args dict holding a local queue. Each loop iteration then
    waits for whichever comes first:

    * an item from the local queue — a ``(message, position)`` pair that is
      printed with a timestamp taken from the position's leading number;
    * a line from *inputs* — optionally encrypted with *password*, then
      published on *channel* through ``args['connection']``;
    * the *stop* future — ends the loop.

    On exit the connection (if one was stored in args) is closed, the
    subscription task is cancelled and reaped, and
    ``exit_from_event_loop_thread(loop, stop)`` is always called.

    Args:
        password: when not None, outgoing text is encrypted with it; incoming
            messages flagged ``encrypted`` are decrypted with it.
        loop: event loop forwarded to ``exit_from_event_loop_thread``.
        inputs: object whose ``get()`` returns an awaitable yielding user text.
        stop: future whose completion ends the client loop.
    '''
    credentials = createCredentials(role, secret)

    # Let the queue bind to the running loop: Queue(loop=...) was deprecated
    # in Python 3.8 and raises TypeError since 3.10. Items are unpacked below
    # as (message, position) pairs — presumably put there by
    # MessageHandlerClass via args['queue']; confirm against the handler.
    q: asyncio.Queue[Any] = asyncio.Queue()

    args = {'verbose': verbose, 'queue': q}

    task = asyncio.ensure_future(
        subscribeClient(url, credentials, channel, position, stream_sql,
                        MessageHandlerClass, args))
    addTaskCleanup(task)

    try:
        while True:
            incoming: asyncio.Future[Any] = asyncio.ensure_future(q.get())
            outgoing: asyncio.Future[Any] = asyncio.ensure_future(inputs.get())
            done: Set[asyncio.Future[Any]]
            pending: Set[asyncio.Future[Any]]
            done, pending = await asyncio.wait(
                [incoming, outgoing, stop],
                return_when=asyncio.FIRST_COMPLETED)

            # Cancel pending tasks to avoid leaking them.
            if incoming in pending:
                incoming.cancel()
            if outgoing in pending:
                outgoing.cancel()

            if incoming in done:
                try:
                    (message, position) = incoming.result()
                except websockets.exceptions.ConnectionClosed:
                    break
                else:
                    data = message.get('data', {})
                    user = data.get('user', 'unknown user')
                    text = data.get('text', '<invalid message>')
                    encrypted = data.get('encrypted', False)
                    if encrypted:
                        text = decrypt(text, password)

                    # Use redis position to get a datetime
                    timestamp = position.split('-')[0]
                    dt = datetime.datetime.fromtimestamp(int(timestamp) / 1000)
                    dtFormatted = dt.strftime('[%H:%M:%S]')

                    # Right-align user names up to 12 characters.
                    maxUserNameLength = 12
                    padding = (maxUserNameLength - len(user)) * ' '

                    user = colorize(user)
                    print_during_input(
                        f'{dtFormatted} {padding} {user}: {text}')

            if outgoing in done:
                text = outgoing.result()

                messageId = uuid.uuid4().hex  # FIXME needed ?

                encrypted = False
                if password is not None:
                    text = encrypt(text, password)
                    encrypted = True

                message = {
                    'data': {
                        'encrypted': encrypted,
                        'user': username,
                        'text': text
                    },
                    'id': messageId,
                }

                await args['connection'].publish(channel, message)

            if stop in done:
                break

    except Exception as e:
        logging.error(f'Caught exception: {e}')

    finally:
        try:
            connection = args.get('connection')
            if connection is not None:
                closeStatus = await connection.close()
                print_over_input(f"Connection closed: {closeStatus}.")

            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                # Expected: awaiting a task we just cancelled re-raises
                # CancelledError; swallowing it lets cleanup continue.
                pass
        finally:
            # Always release the event-loop thread, even if closing the
            # connection or reaping the subscription task failed.
            exit_from_event_loop_thread(loop, stop)
Example #6
0
def _subscribeErrorResponse(pdu: JsonDict, errMsg: str) -> JsonDict:
    '''Build an ``rtm/subscribe/error`` reply to *pdu* carrying *errMsg*.'''
    return {
        "action": "rtm/subscribe/error",
        "id": pdu.get('id', 1),
        "body": {
            "error": errMsg
        },
    }


async def handleSubscribe(state: ConnectionState, ws, app: Dict, pdu: JsonDict,
                          serializedPdu: str):
    '''Handle a subscribe PDU and start streaming channel data to *ws*.

    The PDU body may carry ``channel`` and/or ``subscription_id`` (each one
    defaults to the other). It may also carry an optional stream-SQL
    ``filter``, whose channel then replaces ``channel``, a ``position``, and
    a ``batch_size`` (default 1).

    Invalid requests get an ``rtm/subscribe/error`` reply and the function
    returns early. Some of these paths also store the reply in
    ``state.error`` and/or set ``state.ok = False``.

    On success, a ``redisSubscriber`` task is started on a dedicated redis
    client for ``'<appkey>::<channel>'``. It is stored in
    ``state.subscriptions`` under ``subscription_id + connection_id`` and
    counted via ``app['stats'].incrSubscriptions``. The ``rtm/subscribe/ok``
    reply is sent from ``MessageHandlerClass.on_init``, presumably once
    redisSubscriber reports its init info.

    Note: *serializedPdu* is not used here.
    '''
    body = pdu.get('body', {})
    channel = body.get('channel')

    subscriptionId = body.get('subscription_id')

    if channel is None and subscriptionId is None:
        errMsg = 'missing channel and subscription_id'
        logging.warning(errMsg)
        response = _subscribeErrorResponse(pdu, errMsg)
        await state.respond(ws, response)
        return

    # A negative max_subscriptions disables the limit.
    maxSubs = app['max_subscriptions']
    if maxSubs >= 0 and len(state.subscriptions) + 1 > maxSubs:
        errMsg = f'subscriptions count over max limit: {maxSubs}'
        logging.warning(errMsg)
        response = _subscribeErrorResponse(pdu, errMsg)
        state.ok = False
        state.error = response
        await state.respond(ws, response)
        return

    if channel is None:
        channel = subscriptionId

    if subscriptionId is None:
        subscriptionId = channel

    filterStr = body.get('filter')
    hasFilter = filterStr not in ('', None)

    try:
        streamSQLFilter = StreamSqlFilter(filterStr) if hasFilter else None
    except InvalidStreamSQLError:
        errMsg = f'Invalid SQL expression {filterStr}'
        logging.warning(errMsg)
        response = _subscribeErrorResponse(pdu, errMsg)
        state.error = response
        await state.respond(ws, response)
        return

    # A filter names its own source channel, which overrides the request's.
    if hasFilter and streamSQLFilter is not None:
        channel = streamSQLFilter.channel

    position = body.get('position')
    if not validatePosition(position):
        errMsg = f'Invalid position: {position}'
        logging.warning(errMsg)
        response = _subscribeErrorResponse(pdu, errMsg)
        state.ok = False
        state.error = response
        await state.respond(ws, response)
        return

    batchSize = body.get('batch_size', 1)
    try:
        batchSize = int(batchSize)
    except (TypeError, ValueError):
        # int() raises ValueError for non-numeric strings and TypeError for
        # JSON null / list / object values; both are client errors.
        errMsg = f'Invalid batch size: {batchSize}'
        logging.warning(errMsg)
        response = _subscribeErrorResponse(pdu, errMsg)
        state.ok = False
        state.error = response
        await state.respond(ws, response)
        return

    response = {
        "action": "rtm/subscribe/ok",
        "id": pdu.get('id', 1),
        "body": {
            # FIXME: we should set the position by querying
            # the redis stream, inside the MessageHandler
            "position": "1519190184-559034812775",
            "subscription_id": subscriptionId,
        },
    }

    class MessageHandlerClass(RedisSubscriberMessageHandlerClass):
        '''Forwards redis stream entries to the websocket in batches.'''

        def __init__(self, args):
            self.cnt = 0  # total messages sent on this subscription
            self.cntPerSec = 0  # messages sent since the last rate log line
            self.throttle = Throttle(seconds=1)
            self.ws = args['ws']
            self.subscriptionId = args['subscription_id']
            self.hasFilter = args['has_filter']
            self.streamSQLFilter = args['stream_sql_filter']
            self.appkey = args['appkey']
            self.serverStats = args['stats']
            self.state = args['state']
            self.subscribeResponse = args['subscribe_response']
            self.app = args['app']
            self.channel = args['channel']
            self.batchSize = args['batch_size']
            self.idIterator = itertools.count()

            # Messages buffered until batchSize of them are available.
            self.messages = []

        def log(self, msg):
            self.state.log(msg)

        async def on_init(self, initInfo):
            '''Send the subscribe reply, merged with *initInfo*.

            Sends an error reply instead when ``initInfo['success']`` is
            missing or falsy.
            '''
            response = self.subscribeResponse
            response['body'].update(initInfo)

            if not initInfo.get('success', False):
                msgId = response['id']
                response = {
                    'action': 'rtm/subscribe/error',
                    'id': msgId,
                    'body': {
                        'error':
                        'subscribe error: server cannot connect to redis'
                    },
                }

            # Send response.
            await self.state.respond(self.ws, response)

        async def handleMsg(self, msg: dict, position: str,
                            payloadSize: int) -> bool:
            '''Buffer one stream entry and flush a batch when it is full.

            Always returns True (presumably "keep consuming" for
            redisSubscriber — confirm against its contract).
            '''
            # Input msg is the full serialized publish pdu.
            # Extract the real message out of it.
            msg = msg.get('body', {}).get('message')

            self.serverStats.updateSubscribed(self.state.role, payloadSize)
            self.serverStats.updateChannelSubscribed(self.channel, payloadSize)

            if self.hasFilter:
                filterOutput = self.streamSQLFilter.match(
                    msg.get('messages') or msg)  # noqa
                if not filterOutput:
                    return True
                else:
                    msg = filterOutput

            self.messages.append(msg)
            if len(self.messages) < self.batchSize:
                return True

            assert position is not None

            pdu = {
                "action": "rtm/subscription/data",
                "id": next(self.idIterator),
                "body": {
                    "subscription_id": self.subscriptionId,
                    "messages": self.messages,
                    "position": position,
                },
            }
            serializedPdu = json.dumps(pdu)
            self.state.log(f"> {serializedPdu} at position {position}")

            await self.ws.send(serializedPdu)

            self.cnt += len(self.messages)
            self.cntPerSec += len(self.messages)

            self.messages = []

            # Log the throughput at most once per throttle window.
            if self.throttle.exceedRate():
                return True

            self.state.log(f"#messages {self.cnt} msg/s {self.cntPerSec}")
            self.cntPerSec = 0

            return True

    appChannel = '{}::{}'.format(state.appkey, channel)

    # We need to create a new connection as reading from it will be blocking
    redisClient = app['redis_clients'].makeRedisClient()

    task = asyncio.ensure_future(
        redisSubscriber(
            redisClient.redis,
            appChannel,
            position,
            MessageHandlerClass,
            {
                'ws': ws,
                'subscription_id': subscriptionId,
                'has_filter': hasFilter,
                'stream_sql_filter': streamSQLFilter,
                'appkey': state.appkey,
                'stats': app['stats'],
                'state': state,
                'subscribe_response': response,
                'app': app,
                'channel': channel,
                'batch_size': batchSize,
            },
        ))
    addTaskCleanup(task)

    key = subscriptionId + state.connection_id
    state.subscriptions[key] = (task, state.role)

    app['stats'].incrSubscriptions(state.role)