Пример #1
0
class PubSub(BasePubSub):
    """
    ZeroMQ-backed PUB/SUB engine for the application.

    Events are published through a PUB socket and received through a SUB
    socket. Depending on the ``zmq_pub_sub_proxy`` option the sockets either
    talk to an XPUB/XSUB proxy or bind/connect directly between instances.
    """

    def __init__(self, application):
        super(PubSub, self).__init__(application)
        # assigned from an IOLoop callback scheduled by initialize()
        self.sub_stream = None

    def initialize(self):
        """
        Create the ZeroMQ context, the PUB and SUB sockets, and schedule
        wrapping the SUB socket into a stream on the IOLoop.
        """
        self.zmq_context = zmq.Context()
        opts = self.application.settings["options"]
        self.zmq_pub_sub_proxy = opts.zmq_pub_sub_proxy
        use_proxy = self.zmq_pub_sub_proxy

        pub_socket = self.zmq_context.socket(zmq.PUB)
        # discard undelivered messages as soon as the socket is closed
        pub_socket.setsockopt(zmq.LINGER, 0)

        if use_proxy:
            # publish into the XSUB side of the proxy
            self.zmq_xsub = opts.zmq_xsub
            pub_socket.connect(self.zmq_xsub)
        else:
            # no proxy: bind our own PUB endpoint; the port is either derived
            # from the application port or taken from options directly
            if opts.zmq_pub_port_shift:
                self.zmq_pub_port = opts.port - opts.zmq_pub_port_shift
            else:
                self.zmq_pub_port = opts.zmq_pub_port
            pub_endpoint = "tcp://%s:%s" % (opts.zmq_pub_listen, str(self.zmq_pub_port))
            pub_socket.bind(pub_endpoint)

        self.pub_stream = ZMQStream(pub_socket)

        # SUB socket receiving events from every application instance
        sub_socket = self.zmq_context.socket(zmq.SUB)
        if use_proxy:
            self.zmq_xpub = opts.zmq_xpub
            sub_socket.connect(self.zmq_xpub)
        else:
            self.zmq_sub_address = opts.zmq_sub_address
            for peer in self.zmq_sub_address:
                sub_socket.connect(peer)

        for service_channel in (CONTROL_CHANNEL, ADMIN_CHANNEL):
            sub_socket.setsockopt_string(zmq.SUBSCRIBE, six.u(service_channel))

        def start_receiving():
            # wrap the SUB socket and route every incoming frame set
            self.sub_stream = ZMQStream(sub_socket)
            self.sub_stream.on_recv(self.dispatch_published_message)

        tornado.ioloop.IOLoop.instance().add_callback(start_receiving)

        if use_proxy:
            summary = "ZeroMQ XPUB: {0}, XSUB: {1}".format(self.zmq_xpub, self.zmq_xsub)
        else:
            summary = "ZeroMQ PUB - {0}; subscribed to {1}".format(self.zmq_pub_port, self.zmq_sub_address)
        logger.info(summary)

    def publish(self, channel, message, method=None):
        """
        JSON-encode ``message`` and publish it into ``channel``.

        ``message`` is tagged in place with ``message_type`` (``method`` or
        ``DEFAULT_PUBLISH_METHOD``) before encoding.
        """
        message["message_type"] = method or self.DEFAULT_PUBLISH_METHOD
        payload = json_encode(message)
        self.pub_stream.send_multipart([utf8(channel), utf8(payload)])

    @coroutine
    def dispatch_published_message(self, multipart_message):
        """
        Route a received ``[channel, data]`` frame pair to the control,
        admin or regular channel handler.
        """
        channel, message_data = multipart_message[0], multipart_message[1]
        if six.PY3:
            channel, message_data = channel.decode(), message_data.decode()

        if channel == CONTROL_CHANNEL:
            yield self.handle_control_message(message_data)
        elif channel == ADMIN_CHANNEL:
            yield self.handle_admin_message(message_data)
        else:
            yield self.handle_channel_message(channel, message_data)

    def subscribe_key(self, subscription_key):
        """Add a SUB filter for ``subscription_key``."""
        stream = self.sub_stream
        stream.setsockopt_string(zmq.SUBSCRIBE, six.u(subscription_key))

    def unsubscribe_key(self, subscription_key):
        """Remove the SUB filter for ``subscription_key``."""
        stream = self.sub_stream
        stream.setsockopt_string(zmq.UNSUBSCRIBE, six.u(subscription_key))
Пример #2
0
class ZmqPubSub(object):
    """
    This class manages application PUB/SUB logic.

    A PUB socket publishes ``[channel, method, message]`` frames. A SUB
    socket receives frames from all application instances (directly or via
    an XPUB/XSUB proxy) and dispatches them to the control, admin or client
    channel handlers. ``subscriptions`` maps a subscription key to the
    clients (keyed by ``client.uid``) registered on it.
    """

    def __init__(self, application):
        self.application = application
        # subscription key -> {client uid: client}
        self.subscriptions = {}
        # created from an IOLoop callback scheduled by init_sockets()
        self.sub_stream = None

    def init_sockets(self):
        """
        Create the ZeroMQ context and PUB/SUB sockets and start listening.
        """
        self.zmq_context = zmq.Context()
        options = self.application.settings["options"]

        self.zmq_pub_sub_proxy = options.zmq_pub_sub_proxy

        # create PUB socket to publish instance events into it
        publish_socket = self.zmq_context.socket(zmq.PUB)

        # do not try to send messages after closing
        publish_socket.setsockopt(zmq.LINGER, 0)

        if self.zmq_pub_sub_proxy:
            # application started with XPUB/XSUB proxy
            self.zmq_xsub = options.zmq_xsub
            publish_socket.connect(self.zmq_xsub)
        else:

            # application started without XPUB/XSUB proxy
            if options.zmq_pub_port_shift:
                # calculate zmq pub port number
                zmq_pub_port = options.port - options.zmq_pub_port_shift
            else:
                zmq_pub_port = options.zmq_pub_port

            self.zmq_pub_port = zmq_pub_port

            publish_socket.bind("tcp://%s:%s" % (options.zmq_pub_listen, str(self.zmq_pub_port)))

        # wrap pub socket into ZeroMQ stream
        self.pub_stream = ZMQStream(publish_socket)

        # create SUB socket listening to all events from all app instances
        subscribe_socket = self.zmq_context.socket(zmq.SUB)

        if self.zmq_pub_sub_proxy:
            # application started with XPUB/XSUB proxy
            self.zmq_xpub = options.zmq_xpub
            subscribe_socket.connect(self.zmq_xpub)
        else:
            # application started without XPUB/XSUB proxy
            self.zmq_sub_address = options.zmq_sub_address
            for address in self.zmq_sub_address:
                subscribe_socket.connect(address)

        subscribe_socket.setsockopt_string(zmq.SUBSCRIBE, six.u(CONTROL_CHANNEL))

        subscribe_socket.setsockopt_string(zmq.SUBSCRIBE, six.u(ADMIN_CHANNEL))

        def listen_socket():
            # wrap sub socket into ZeroMQ stream and set its on_recv callback
            self.sub_stream = ZMQStream(subscribe_socket)
            self.sub_stream.on_recv(self.dispatch_published_message)

        tornado.ioloop.IOLoop.instance().add_callback(listen_socket)

        if self.zmq_pub_sub_proxy:
            logger.info("ZeroMQ XPUB: {0}, XSUB: {1}".format(self.zmq_xpub, self.zmq_xsub))
        else:
            logger.info("ZeroMQ PUB - {0}; subscribed to {1}".format(self.zmq_pub_port, self.zmq_sub_address))

    def publish(self, channel, message, method=None):
        """
        Publish message into channel of stream.

        Sends three frames: channel, method (``DEFAULT_PUBLISH_METHOD`` when
        not given) and the message itself.
        """
        method = method or DEFAULT_PUBLISH_METHOD
        to_publish = [utf8(channel), utf8(method), utf8(message)]
        self.pub_stream.send_multipart(to_publish)

    def get_subscription_key(self, project_id, namespace, channel):
        """
        Create subscription name to catch messages from specific
        project, namespace and channel.
        """
        return str(CHANNEL_NAME_SEPARATOR.join([project_id, namespace, channel, CHANNEL_SUFFIX]))

    @coroutine
    def dispatch_published_message(self, multipart_message):
        """
        Got message, decide what is it and dispatch into right
        application handler.
        """
        channel = multipart_message[0]
        method = multipart_message[1]
        message_data = multipart_message[2]
        if six.PY3:
            # ZeroMQ frames are bytes. Channel names and subscription keys
            # are str, so the channel must be decoded before comparing or
            # looking it up; the method is decoded so it is passed on as text.
            channel = channel.decode()
            method = method.decode()
            message_data = message_data.decode()
        if channel == CONTROL_CHANNEL:
            yield self.handle_control_message(message_data)
        elif channel == ADMIN_CHANNEL:
            yield self.handle_admin_message(message_data)
        else:
            yield self.handle_channel_message(channel, method, message_data)

    @coroutine
    def handle_admin_message(self, message):
        """
        Send ``message`` to every admin connection still registered.
        """
        # iterate over a snapshot: a send may cause a connection to be
        # removed from admin_connections while we loop
        for uid, connection in list(six.iteritems(self.application.admin_connections)):
            if uid in self.application.admin_connections:
                connection.send(message)

    @coroutine
    def handle_channel_message(self, channel, method, message):
        """
        Wrap ``message`` into a Response and send it to every client
        subscribed to ``channel``.
        """
        if channel not in self.subscriptions:
            raise Return((True, None))

        response = Response(method=method, body=message)
        prepared_response = response.as_message()

        # iterate over a snapshot: clients may unsubscribe (and the channel
        # may disappear) while messages are being sent
        for uid, client in list(six.iteritems(self.subscriptions[channel])):
            if channel in self.subscriptions and uid in self.subscriptions[channel]:
                client.send(prepared_response)

    @coroutine
    def handle_control_message(self, message):
        """
        Handle control message.
        """
        message = json_decode(message)

        app_id = message.get("app_id")
        method = message.get("method")
        params = message.get("params")

        if app_id and app_id == self.application.uid:
            # application id must be set when we don't want to do
            # make things twice for the same application. Setting
            # app_id means that we don't want to process control
            # message when it is appear in application instance if
            # application uid matches app_id
            raise Return((True, None))

        func = getattr(self.application, "handle_%s" % method, None)
        if not func:
            raise Return((None, "method not found"))

        result, error = yield func(params)
        raise Return((result, error))

    def add_subscription(self, project_id, namespace_name, channel, client):
        """
        Subscribe application on channel if necessary and register client
        to receive messages from that channel.
        """
        subscription_key = self.get_subscription_key(project_id, namespace_name, channel)

        if subscription_key not in self.subscriptions:
            # ZeroMQ keeps one filter instance per SUBSCRIBE call and each
            # UNSUBSCRIBE removes only one instance, so subscribe once per
            # key to let remove_subscription really drop the topic.
            self.sub_stream.setsockopt_string(zmq.SUBSCRIBE, six.u(subscription_key))
            self.subscriptions[subscription_key] = {}

        self.subscriptions[subscription_key][client.uid] = client

    def remove_subscription(self, project_id, namespace_name, channel, client):
        """
        Unsubscribe application from channel if necessary and unregister client
        from receiving messages from that channel.
        """
        subscription_key = self.get_subscription_key(project_id, namespace_name, channel)

        clients = self.subscriptions.get(subscription_key)
        if clients is None:
            return

        clients.pop(client.uid, None)

        if not clients:
            # last client gone: drop the ZeroMQ filter and the registry entry
            self.sub_stream.setsockopt_string(zmq.UNSUBSCRIBE, six.u(subscription_key))
            del self.subscriptions[subscription_key]
Пример #3
0
class PubSub(BasePubSub):
    """
    PUB/SUB engine built on ZeroMQ.

    One PUB socket publishes this instance's events and one SUB socket
    receives events from every instance. With ``zmq_pub_sub_proxy`` enabled
    both sockets talk to an XPUB/XSUB proxy. Otherwise the PUB socket binds
    locally and the SUB socket connects to every configured address.
    """
    NAME = 'ZeroMQ'

    def __init__(self, application):
        super(PubSub, self).__init__(application)
        # Filled in by initialize(); sub_stream is created even later,
        # from a callback scheduled on the IOLoop.
        self.sub_stream = None
        self.pub_stream = None
        self.zmq_context = None
        self.zmq_pub_sub_proxy = None
        self.zmq_xpub = None
        self.zmq_xsub = None
        self.zmq_pub_port = None
        self.zmq_sub_address = None

    def initialize(self):
        """
        Build the ZeroMQ context and sockets, then schedule SUB listening.
        """
        self.zmq_context = zmq.Context()
        options = self.application.settings['options']
        self.zmq_pub_sub_proxy = options.zmq_pub_sub_proxy
        proxied = self.zmq_pub_sub_proxy

        pub = self.zmq_context.socket(zmq.PUB)
        # undelivered messages are dropped as soon as the socket is closed
        pub.setsockopt(zmq.LINGER, 0)

        if not proxied:
            # standalone: bind our own PUB endpoint, port either shifted
            # from the application port or configured explicitly
            shift = options.zmq_pub_port_shift
            self.zmq_pub_port = (options.port - shift) if shift else options.zmq_pub_port
            pub.bind(
                "tcp://%s:%s" %
                (options.zmq_pub_listen, str(self.zmq_pub_port)))
        else:
            # proxied: publish into the proxy's XSUB endpoint
            self.zmq_xsub = options.zmq_xsub
            pub.connect(self.zmq_xsub)

        self.pub_stream = ZMQStream(pub)

        sub = self.zmq_context.socket(zmq.SUB)
        if not proxied:
            self.zmq_sub_address = options.zmq_sub_address
            for address in self.zmq_sub_address:
                sub.connect(address)
        else:
            self.zmq_xpub = options.zmq_xpub
            sub.connect(self.zmq_xpub)

        sub.setsockopt_string(zmq.SUBSCRIBE, six.u(CONTROL_CHANNEL))
        sub.setsockopt_string(zmq.SUBSCRIBE, six.u(ADMIN_CHANNEL))

        def attach_sub_stream():
            # runs on the IOLoop: wrap SUB socket and route incoming frames
            stream = ZMQStream(sub)
            self.sub_stream = stream
            stream.on_recv(self.dispatch_published_message)

        tornado.ioloop.IOLoop.instance().add_callback(attach_sub_stream)

        if proxied:
            logger.info("ZeroMQ XPUB: {0}, XSUB: {1}".format(
                self.zmq_xpub, self.zmq_xsub))
        else:
            logger.info("ZeroMQ PUB - {0}; subscribed to {1}".format(
                self.zmq_pub_port, self.zmq_sub_address))

    def publish(self, channel, message, method=None):
        """
        Tag ``message`` with its ``message_type`` (in place), JSON-encode it
        and publish it into ``channel``.
        """
        message["message_type"] = method or self.DEFAULT_PUBLISH_METHOD
        encoded = json_encode(message)
        self.pub_stream.send_multipart([utf8(channel), utf8(encoded)])

    @coroutine
    def dispatch_published_message(self, multipart_message):
        """
        Decode a ``[channel, json]`` frame pair and hand it to the control,
        admin or regular channel handler.
        """
        channel = multipart_message[0]
        channel = channel.decode() if six.PY3 else channel
        message_data = json_decode(multipart_message[1])

        if channel == CONTROL_CHANNEL:
            pending = self.handle_control_message(message_data)
        elif channel == ADMIN_CHANNEL:
            pending = self.handle_admin_message(message_data)
        else:
            pending = self.handle_channel_message(channel, message_data)
        yield pending

    def subscribe_key(self, subscription_key):
        """Start receiving messages published under ``subscription_key``."""
        stream = self.sub_stream
        stream.setsockopt_string(zmq.SUBSCRIBE, six.u(subscription_key))

    def unsubscribe_key(self, subscription_key):
        """Stop receiving messages published under ``subscription_key``."""
        stream = self.sub_stream
        stream.setsockopt_string(zmq.UNSUBSCRIBE, six.u(subscription_key))

    def clean(self):
        """
        Properly close ZeroMQ sockets.

        Each stream is closed only if it was created; the SUB stream has its
        receive callback detached first.
        """
        pub_stream = getattr(self, 'pub_stream', None)
        if pub_stream:
            pub_stream.close()

        sub_stream = getattr(self, 'sub_stream', None)
        if sub_stream:
            sub_stream.stop_on_recv()
            sub_stream.close()
Пример #4
0
class Connection(object):
    """
    This is a base class describing a single connection of client from
    web browser.

    It implements the client protocol methods (auth, subscribe, unsubscribe,
    broadcast). Sending and closing only take effect when the concrete
    object is a ``SockJSConnection``.
    """
    # maximum auth validation requests before returning error to client
    MAX_AUTH_ATTEMPTS = 5

    # interval unit in milliseconds for back off
    BACK_OFF_INTERVAL = 100

    # maximum timeout between authorization attempts in back off
    BACK_OFF_MAX_TIMEOUT = 5000

    def close_connection(self):
        """
        General method for closing connection.
        """
        if isinstance(self, SockJSConnection):
            self.close()

    def send_message(self, message):
        """
        Send message to client
        """
        if isinstance(self, SockJSConnection):
            self.send(message)

    def send_ack(self, msg_id=None, method=None, result=None, error=None):
        """
        Send an acknowledgement built by ``make_ack`` to the client.
        """
        self.send_message(
            self.make_ack(
                msg_id=msg_id,
                method=method,
                result=result,
                error=error
            )
        )

    def make_ack(self, msg_id=None, method=None, result=None, error=None):
        """
        Return a JSON-encoded acknowledgement for a client request.
        """
        to_return = {
            'ack': True,
            'id': msg_id,
            'method': method,
            'result': result,
            'error': error
        }
        return json_encode(to_return)

    @coroutine
    def handle_auth(self, params):
        """
        Authenticate the connection.

        The method works in several steps:

        - Check the client token against the project's secret key.
        - Optionally validate the user through the project's
          ``validate_url``, using a randomized exponential back off between
          attempts.
        - Load the categories allowed by ``permissions``.
        - Open a SUB socket for this connection.

        Resolves to a ``(result, error)`` tuple via ``Return``.
        """
        if self.is_authenticated:
            raise Return((True, None))

        token = params["token"]
        user = params["user"]
        project_id = params['project_id']
        permissions = params["permissions"]

        project, error = yield state.get_project_by_id(project_id)
        if error:
            self.close_connection()
        if not project:
            raise Return((None, "project not found"))

        secret_key = project['secret_key']

        if token != auth.get_client_token(secret_key, project_id, user):
            raise Return((None, "invalid token"))

        if user and project.get('validate_url', None):

            http_client = AsyncHTTPClient()
            request = HTTPRequest(
                project['validate_url'],
                method="POST",
                body=json_encode({'user': user, 'permissions': permissions}),
                request_timeout=1
            )

            max_auth_attempts = project.get(
                'auth_attempts'
            ) or self.MAX_AUTH_ATTEMPTS

            back_off_interval = project.get(
                'back_off_interval'
            ) or self.BACK_OFF_INTERVAL

            back_off_max_timeout = project.get(
                'back_off_max_timeout'
            ) or self.BACK_OFF_MAX_TIMEOUT

            attempts = 0

            while attempts < max_auth_attempts:

                # back-off counter is kept per project on the application
                current_attempts = self.application.back_off.setdefault(project_id, 0)

                factor = random.randint(0, 2 ** current_attempts - 1)
                timeout = min(factor * back_off_interval, back_off_max_timeout)

                # wait before next authorization request attempt
                yield sleep(float(timeout) / 1000)

                try:
                    response = yield http_client.fetch(request)
                except Exception as err:
                    # Tornado raises HTTPError for non-2xx responses, so an
                    # explicit denial from the validation endpoint ends up here.
                    if getattr(err, 'code', None) == 403:
                        self.application.back_off[project_id] = 0
                        raise Return((None, "permission denied"))
                    # Otherwise let it fail and try again after some timeout
                    # while attempts remain. Exception, not BaseException, so
                    # GeneratorExit/KeyboardInterrupt are not swallowed here.
                else:
                    # reset back-off attempts
                    self.application.back_off[project_id] = 0

                    if response.code == 200:
                        self.is_authenticated = True
                        break
                    elif response.code == 403:
                        raise Return((None, "permission denied"))
                attempts += 1
                self.application.back_off[project_id] += 1
        else:
            self.is_authenticated = True

        if not self.is_authenticated:
            raise Return((None, "permission validation error"))

        categories, error = yield state.get_project_categories(project)
        if error:
            self.close_connection()

        # empty/missing permissions mean every category is allowed
        self.categories = {}
        for category in categories:
            if not permissions or category['name'] in permissions:
                self.categories[category['name']] = category

        self.uid = uuid.uuid4().hex
        self.project = project
        self.permissions = permissions
        self.user = user
        self.channels = {}
        self.start_heartbeat()

        # allow broadcast from client only into bidirectional categories
        self.bidirectional_categories = {}
        for category_name, category in six.iteritems(self.categories):
            if category.get('bidirectional', False):
                self.bidirectional_categories[category_name] = category

        # Use the process-wide shared context instead of creating a new
        # context (and ZeroMQ I/O thread) for every connected client.
        context = zmq.Context.instance()
        subscribe_socket = context.socket(zmq.SUB)

        if self.application.zmq_pub_sub_proxy:
            subscribe_socket.connect(self.application.zmq_xpub)
        else:
            for address in self.application.zmq_sub_address:
                subscribe_socket.connect(address)

        self.sub_stream = ZMQStream(subscribe_socket)
        self.sub_stream.on_recv(self.on_message_published)

        raise Return((True, None))

    @coroutine
    def handle_subscribe(self, params):
        """
        Subscribe authenticated connection on channels.

        ``params['to']`` maps category names to lists of channel names. The
        following are skipped:

        - categories not available to this connection;
        - channel values that are not lists;
        - channels outside the client's permissions.
        """
        subscribe_to = params.get('to')

        if not subscribe_to:
            raise Return((True, None))

        project_id = self.project['_id']

        connections = self.application.connections

        if project_id not in connections:
            connections[project_id] = {}

        # Look the user up inside this project's mapping. Checking the
        # project-keyed top level reset the user's entry on every subscribe
        # and dropped the user's other connections.
        if self.user and self.user not in connections[project_id]:
            connections[project_id][self.user] = {}

        if self.user:
            connections[project_id][self.user][self.uid] = self

        for category_name, channels in six.iteritems(subscribe_to):

            if category_name not in self.categories:
                # attempt to subscribe on not allowed category
                continue

            if not channels or not isinstance(channels, list):
                # attempt to subscribe without channels provided
                continue

            category_id = self.categories[category_name]['_id']

            allowed_channels = self.permissions.get(category_name) if self.permissions else []

            if not isinstance(allowed_channels, list):
                # permissions for this category are not a channel list:
                # every channel would be skipped, so skip the category
                continue

            for channel in channels:

                if allowed_channels and channel not in allowed_channels:
                    # attempt to subscribe on not allowed channel
                    continue

                channel_to_subscribe = rpc.create_channel_name(
                    project_id,
                    category_id,
                    channel
                )
                self.sub_stream.setsockopt_string(
                    zmq.SUBSCRIBE, six.u(channel_to_subscribe)
                )

                if category_name not in self.channels:
                    self.channels[category_name] = {}
                self.channels[category_name][channel_to_subscribe] = True

        raise Return((True, None))

    @coroutine
    def handle_unsubscribe(self, params):
        """
        Unsubscribe authenticated connection from channels.

        ``params['from']`` maps category names to lists of channel names.
        """
        unsubscribe_from = params.get('from')

        if not unsubscribe_from:
            raise Return((True, None))

        project_id = self.project['_id']

        for category_name, channels in six.iteritems(unsubscribe_from):

            if category_name not in self.categories:
                # attempt to unsubscribe from not allowed category
                continue

            if not channels or not isinstance(channels, list):
                # attempt to unsubscribe from unknown channels
                continue

            category_id = self.categories[category_name]['_id']

            for channel in channels:

                allowed_channels = self.permissions[category_name] if self.permissions else []

                if allowed_channels and channel not in allowed_channels:
                    # attempt to unsubscribe from not allowed channel
                    continue

                channel_to_unsubscribe = rpc.create_channel_name(
                    project_id,
                    category_id,
                    channel
                )
                self.sub_stream.setsockopt_string(
                    zmq.UNSUBSCRIBE, six.u(channel_to_unsubscribe)
                )

                try:
                    del self.channels[category_name][channel_to_unsubscribe]
                except KeyError:
                    pass

        raise Return((True, None))

    @coroutine
    def handle_broadcast(self, params):
        """
        Broadcast a client message into a bidirectional category channel.
        """
        category = params.get('category')
        channel = params.get('channel')

        if category not in self.categories:
            raise Return((None, 'category does not exist or permission denied'))

        if category not in self.bidirectional_categories:
            raise Return((None, 'one-way category'))

        allowed_channels = self.permissions.get(category) if self.permissions else []

        if allowed_channels and channel not in allowed_channels:
            # attempt to broadcast into not allowed channel
            raise Return((None, 'channel permission denied'))

        result, error = yield rpc.process_broadcast(
            self.application,
            self.project,
            self.bidirectional_categories,
            params
        )

        raise Return((result, error))

    @coroutine
    def on_centrifuge_connection_message(self, message):
        """
        Called when message from client received.

        Parses and validates the request, then dispatches it to the matching
        ``handle_<method>`` and acknowledges the result. Every rejected
        request is acknowledged with an error and processing stops there.
        """
        try:
            data = json_decode(message)
        except ValueError:
            self.send_ack(error='malformed JSON data')
            raise Return(False)

        try:
            validate(data, req_schema)
        except ValidationError as e:
            self.send_ack(error=str(e))
            # do not go on processing a request that failed validation
            raise Return(True)

        msg_id = data.get('id', None)
        method = data.get('method')
        params = data.get('params')

        if method != 'auth' and not self.is_authenticated:
            self.send_ack(error='unauthorized')
            raise Return(True)

        func = getattr(self, 'handle_%s' % method, None)

        if not func:
            self.send_ack(
                msg_id=msg_id,
                method=method,
                error="unknown method %s" % method
            )
            # nothing to call: stop instead of failing further down
            raise Return(True)

        try:
            validate(params, client_params_schema[method])
        except ValidationError as e:
            self.send_ack(msg_id=msg_id, method=method, error=str(e))
            raise Return(True)

        result, error = yield func(params)

        self.send_ack(msg_id=msg_id, method=method, result=result, error=error)

        raise Return(True)

    def start_heartbeat(self):
        """
        In ideal case we work with websocket connections with heartbeat available
        by default. But there are lots of other transports whose heartbeat must be
        started manually. Do it here.
        """
        if isinstance(self, SockJSConnection):
            if self.session:
                if self.session.transport_name != 'rawwebsocket':
                    self.session.start_heartbeat()
            else:
                self.close_connection()

    def on_message_published(self, message):
        """
        Called when message received from one of channels client subscribed to.

        The first frame is ``<channel><CHANNEL_DATA_SEPARATOR><data>``; only
        the data part is forwarded to the client.
        """
        actual_message = message[0]
        if six.PY3:
            actual_message = actual_message.decode()
        self.send_message(
            actual_message.split(rpc.CHANNEL_DATA_SEPARATOR, 1)[1]
        )

    def clean_up(self):
        """
        Unsubscribe connection from channels, clean up zmq sockets.
        """
        if hasattr(self, 'sub_stream') and not self.sub_stream.closed():
            self.sub_stream.on_recv(None)
            self.sub_stream.close()

        if not self.is_authenticated:
            return

        project_id = self.project['_id']

        connections = self.application.connections

        self.channels = None

        if project_id not in connections:
            return

        if self.user not in connections[project_id]:
            return

        try:
            del connections[project_id][self.user][self.uid]
        except KeyError:
            pass

        # clean connections
        if not connections[project_id][self.user]:
            try:
                del connections[project_id][self.user]
            except KeyError:
                pass
            if not connections[project_id]:
                try:
                    del connections[project_id]
                except KeyError:
                    pass

    def on_centrifuge_connection_open(self):
        logger.info('client connected')
        self.is_authenticated = False

    def on_centrifuge_connection_close(self):
        logger.info('client disconnected')
        self.clean_up()
Пример #5
0
class Connection(object):
    """
    This is a base class describing a single connection of client from
    web browser.

    Transport-specific operations (sending, closing, heartbeats) are only
    performed when the instance is also a ``SockJSConnection``.
    """
    # maximum auth validation requests before returning error to client
    MAX_AUTH_ATTEMPTS = 5

    # interval unit in milliseconds for back off
    BACK_OFF_INTERVAL = 100

    # maximum timeout between authorization attempts in back off
    BACK_OFF_MAX_TIMEOUT = 5000

    def close_connection(self):
        """
        General method for closing connection.
        """
        if isinstance(self, (SockJSConnection, )):
            self.close()

    def send_message(self, message):
        """
        Send message to client
        """
        if isinstance(self, SockJSConnection):
            self.send(message)

    def send_ack(self, msg_id=None, method=None, result=None, error=None):
        """
        Build an acknowledgement with ``make_ack`` and send it to the client.
        """
        self.send_message(
            self.make_ack(msg_id=msg_id,
                          method=method,
                          result=result,
                          error=error))

    def make_ack(self, msg_id=None, method=None, result=None, error=None):
        """
        Return a JSON-encoded acknowledgement for a client request.

        The payload carries ``ack: True`` plus the request ``id``, ``method``,
        ``result`` and ``error`` values as given.
        """
        to_return = {
            'ack': True,
            'id': msg_id,
            'method': method,
            'result': result,
            'error': error
        }
        return json_encode(to_return)

    @coroutine
    def handle_auth(self, params):
        """
        Authenticate connection and prepare its per-connection state.

        Checks the client token against the project secret key, optionally
        validates the user's permissions via the project's ``validate_url``
        (retrying with randomized exponential back off), loads the allowed
        categories and creates a ZeroMQ SUB stream for this connection.

        :param params: request params with ``token``, ``user``,
            ``project_id`` and ``permissions`` keys.
        :return: (via tornado ``Return``) ``(True, None)`` on success or
            ``(None, error)`` on failure.
        """
        # HTTPError comes from the same module as AsyncHTTPClient/HTTPRequest
        from tornado.httpclient import HTTPError

        if self.is_authenticated:
            raise Return((True, None))

        token = params["token"]
        user = params["user"]
        project_id = params['project_id']
        permissions = params["permissions"]

        project, error = yield state.get_project_by_id(project_id)
        if error:
            # stop here instead of continuing with an unusable project
            self.close_connection()
            raise Return((None, error))
        if not project:
            raise Return((None, "project not found"))

        secret_key = project['secret_key']

        if token != auth.get_client_token(secret_key, project_id, user):
            raise Return((None, "invalid token"))

        if user and project.get('validate_url', None):

            http_client = AsyncHTTPClient()
            request = HTTPRequest(project['validate_url'],
                                  method="POST",
                                  body=json_encode({
                                      'user': user,
                                      'permissions': permissions
                                  }),
                                  request_timeout=1)

            max_auth_attempts = project.get(
                'auth_attempts') or self.MAX_AUTH_ATTEMPTS

            back_off_interval = project.get(
                'back_off_interval') or self.BACK_OFF_INTERVAL

            back_off_max_timeout = project.get(
                'back_off_max_timeout') or self.BACK_OFF_MAX_TIMEOUT

            attempts = 0

            while attempts < max_auth_attempts:

                # get current back-off attempts counter for project
                current_attempts = self.application.back_off.setdefault(
                    project_id, 0)

                factor = random.randint(0, 2 ** current_attempts - 1)
                timeout = min(factor * back_off_interval,
                              back_off_max_timeout)

                # wait before next authorization request attempt
                yield sleep(float(timeout) / 1000)

                try:
                    response = yield http_client.fetch(request)
                except HTTPError as e:
                    # fetch() raises HTTPError for non-2xx responses, so an
                    # explicit denial from validate_url arrives here
                    if e.code == 403:
                        self.application.back_off[project_id] = 0
                        raise Return((None, "permission denied"))
                    # other HTTP errors (including timeouts): retry
                except Exception:
                    # let it fail and try again after some timeout
                    # until we have auth attempts
                    pass
                else:
                    # reset back-off attempts
                    self.application.back_off[project_id] = 0

                    if response.code == 200:
                        self.is_authenticated = True
                        break
                    elif response.code == 403:
                        # only reachable if fetch is configured not to raise
                        raise Return((None, "permission denied"))
                attempts += 1
                self.application.back_off[project_id] += 1
        else:
            self.is_authenticated = True

        if not self.is_authenticated:
            raise Return((None, "permission validation error"))

        categories, error = yield state.get_project_categories(project)
        if error:
            # without categories there is nothing to iterate below
            self.close_connection()
            raise Return((None, error))

        # empty/absent permissions grant access to every project category
        self.categories = {}
        for category in categories:
            if not permissions or category['name'] in permissions:
                self.categories[category['name']] = category

        self.uid = uuid.uuid4().hex
        self.project = project
        self.permissions = permissions
        self.user = user
        self.channels = {}
        self.start_heartbeat()

        # allow broadcast from client only into bidirectional categories
        self.bidirectional_categories = {}
        for category_name, category in six.iteritems(self.categories):
            if category.get('bidirectional', False):
                self.bidirectional_categories[category_name] = category

        # reuse the process-wide context instead of creating a new context
        # (with its own I/O thread) for every authenticated connection
        context = zmq.Context.instance()
        subscribe_socket = context.socket(zmq.SUB)

        if self.application.zmq_pub_sub_proxy:
            subscribe_socket.connect(self.application.zmq_xpub)
        else:
            for address in self.application.zmq_sub_address:
                subscribe_socket.connect(address)

        self.sub_stream = ZMQStream(subscribe_socket)
        self.sub_stream.on_recv(self.on_message_published)

        raise Return((True, None))

    @coroutine
    def handle_subscribe(self, params):
        """
        Subscribe authenticated connection on channels.

        ``params['to']`` maps category names to lists of channel names.
        Categories the connection has no access to, non-list channel values
        and channels outside the permitted list are silently skipped.
        """
        subscribe_to = params.get('to')

        if not subscribe_to:
            raise Return((True, None))

        project_id = self.project['_id']

        connections = self.application.connections

        project_connections = connections.setdefault(project_id, {})

        if self.user:
            # register this connection under its user without discarding
            # other connections of the same user
            project_connections.setdefault(self.user, {})[self.uid] = self

        for category_name, channels in six.iteritems(subscribe_to):

            if category_name not in self.categories:
                # attempt to subscribe on not allowed category
                continue

            if not channels or not isinstance(channels, list):
                # attempt to subscribe without channels provided
                continue

            category_id = self.categories[category_name]['_id']

            allowed_channels = self.permissions.get(
                category_name) if self.permissions else []

            if not isinstance(allowed_channels, list):
                # malformed permission entry: nothing may be subscribed
                continue

            for channel in channels:

                if allowed_channels and channel not in allowed_channels:
                    # attempt to subscribe on not allowed channel
                    continue

                channel_to_subscribe = rpc.create_channel_name(
                    project_id, category_id, channel)
                self.sub_stream.setsockopt_string(zmq.SUBSCRIBE,
                                                  six.u(channel_to_subscribe))

                self.channels.setdefault(
                    category_name, {})[channel_to_subscribe] = True

        raise Return((True, None))

    @coroutine
    def handle_unsubscribe(self, params):
        """
        Unsubscribe authenticated connection from channels.

        ``params['from']`` maps category names to lists of channel names;
        unknown categories, non-list values and channels outside the
        permitted list are skipped.
        """
        unsubscribe_from = params.get('from')

        if not unsubscribe_from:
            raise Return((True, None))

        project_id = self.project['_id']

        for category_name, channels in six.iteritems(unsubscribe_from):

            if category_name not in self.categories:
                # attempt to unsubscribe from not allowed category
                continue

            if not channels or not isinstance(channels, list):
                # attempt to unsubscribe from unknown channels
                continue

            category_id = self.categories[category_name]['_id']

            # same lookup style as handle_subscribe / handle_broadcast
            allowed_channels = self.permissions.get(
                category_name) if self.permissions else []

            for channel in channels:

                if allowed_channels and channel not in allowed_channels:
                    # attempt to unsubscribe from not allowed channel
                    continue

                channel_to_unsubscribe = rpc.create_channel_name(
                    project_id, category_id, channel)
                self.sub_stream.setsockopt_string(
                    zmq.UNSUBSCRIBE, six.u(channel_to_unsubscribe))

                try:
                    del self.channels[category_name][channel_to_unsubscribe]
                except KeyError:
                    pass

        raise Return((True, None))

    @coroutine
    def handle_broadcast(self, params):
        """
        Publish a client message into a channel of a bidirectional category.

        Access checks run here; the actual publishing is delegated to
        ``rpc.process_broadcast`` whose ``(result, error)`` is returned.
        """
        category = params.get('category')
        channel = params.get('channel')

        if category not in self.categories:
            raise Return(
                (None, 'category does not exist or permission denied'))

        if category not in self.bidirectional_categories:
            raise Return((None, 'one-way category'))

        allowed_channels = self.permissions.get(
            category) if self.permissions else []

        if allowed_channels and channel not in allowed_channels:
            # attempt to broadcast into not allowed channel
            raise Return((None, 'channel permission denied'))

        result, error = yield rpc.process_broadcast(
            self.application, self.project, self.bidirectional_categories,
            params)

        raise Return((result, error))

    @coroutine
    def on_centrifuge_connection_message(self, message):
        """
        Called when message from client received.

        Decodes and validates the request, dispatches it to the matching
        ``handle_<method>`` coroutine and acks the result to the client.
        Every failure path sends an error ack and stops processing.
        """
        try:
            data = json_decode(message)
        except ValueError:
            self.send_ack(error='malformed JSON data')
            raise Return(False)

        try:
            validate(data, req_schema)
        except ValidationError as e:
            self.send_ack(error=str(e))
            # do not go on with data that violates the request schema
            raise Return(False)

        msg_id = data.get('id', None)
        method = data.get('method')
        params = data.get('params')

        if method != 'auth' and not self.is_authenticated:
            self.send_ack(error='unauthorized')
            raise Return(True)

        func = getattr(self, 'handle_%s' % method, None)

        if not func:
            self.send_ack(msg_id=msg_id,
                          method=method,
                          error="unknown method %s" % method)
            # nothing to dispatch to
            raise Return(True)

        try:
            validate(params, client_params_schema[method])
        except ValidationError as e:
            self.send_ack(msg_id=msg_id, method=method, error=str(e))
            raise Return(True)

        result, error = yield func(params)

        self.send_ack(msg_id=msg_id, method=method, result=result, error=error)

        raise Return(True)

    def start_heartbeat(self):
        """
        In ideal case we work with websocket connections with heartbeat available
        by default. But there are lots of other transports whose heartbeat must be
        started manually. Do it here.
        """
        if isinstance(self, SockJSConnection):
            if self.session:
                if self.session.transport_name != 'rawwebsocket':
                    self.session.start_heartbeat()
            else:
                self.close_connection()

    def on_message_published(self, message):
        """
        Called when message received from one of channels client subscribed to.

        Sends the text after the first ``rpc.CHANNEL_DATA_SEPARATOR`` of the
        first frame to the client.
        """
        actual_message = message[0]
        if six.PY3:
            actual_message = actual_message.decode()
        self.send_message(
            actual_message.split(rpc.CHANNEL_DATA_SEPARATOR, 1)[1])

    def clean_up(self):
        """
        Unsubscribe connection from channels, clean up zmq sockets.

        Closes the SUB stream if present and open; for authenticated
        connections also removes the connection from
        ``application.connections`` and prunes emptied user/project entries.
        """
        # ``sub_stream`` is only created on successful auth and may be None
        sub_stream = getattr(self, 'sub_stream', None)
        if sub_stream is not None and not sub_stream.closed():
            sub_stream.on_recv(None)
            sub_stream.close()

        if not self.is_authenticated:
            return

        project_id = self.project['_id']

        connections = self.application.connections

        self.channels = None

        if project_id not in connections:
            return

        project_connections = connections[project_id]

        if self.user not in project_connections:
            return

        user_connections = project_connections[self.user]
        user_connections.pop(self.uid, None)

        # clean connections
        if not user_connections:
            project_connections.pop(self.user, None)
            if not project_connections:
                connections.pop(project_id, None)

    def on_centrifuge_connection_open(self):
        """Log the new connection and mark it unauthenticated."""
        logger.info('client connected')
        self.is_authenticated = False

    def on_centrifuge_connection_close(self):
        """Log the disconnect and release connection resources."""
        logger.info('client disconnected')
        self.clean_up()
Пример #6
0
class ZmqPubSub(object):
    """
    This class manages application PUB/SUB logic.

    A PUB stream publishes this instance's events; a SUB stream receives
    control, admin and client-channel messages and dispatches them to the
    application or to locally subscribed clients.
    """
    def __init__(self, application):
        self.application = application
        # subscription key -> {client uid: client}
        self.subscriptions = {}
        self.sub_stream = None

    def init_sockets(self):
        """
        Create the PUB and SUB sockets (directly or through an XPUB/XSUB
        proxy, depending on options) and attach the SUB stream callback on
        the next IOLoop iteration.
        """
        self.zmq_context = zmq.Context()
        options = self.application.settings['options']

        self.zmq_pub_sub_proxy = options.zmq_pub_sub_proxy

        # create PUB socket to publish instance events into it
        publish_socket = self.zmq_context.socket(zmq.PUB)

        # do not try to send messages after closing
        publish_socket.setsockopt(zmq.LINGER, 0)

        if self.zmq_pub_sub_proxy:
            # application started with XPUB/XSUB proxy
            self.zmq_xsub = options.zmq_xsub
            publish_socket.connect(self.zmq_xsub)
        else:

            # application started without XPUB/XSUB proxy
            if options.zmq_pub_port_shift:
                # calculate zmq pub port number
                zmq_pub_port = options.port - options.zmq_pub_port_shift
            else:
                zmq_pub_port = options.zmq_pub_port

            self.zmq_pub_port = zmq_pub_port

            publish_socket.bind(
                "tcp://%s:%s" %
                (options.zmq_pub_listen, str(self.zmq_pub_port)))

        # wrap pub socket into ZeroMQ stream
        self.pub_stream = ZMQStream(publish_socket)

        # create SUB socket listening to all events from all app instances
        subscribe_socket = self.zmq_context.socket(zmq.SUB)

        if self.zmq_pub_sub_proxy:
            # application started with XPUB/XSUB proxy
            self.zmq_xpub = options.zmq_xpub
            subscribe_socket.connect(self.zmq_xpub)
        else:
            # application started without XPUB/XSUB proxy
            self.zmq_sub_address = options.zmq_sub_address
            for address in self.zmq_sub_address:
                subscribe_socket.connect(address)

        subscribe_socket.setsockopt_string(zmq.SUBSCRIBE,
                                           six.u(CONTROL_CHANNEL))

        subscribe_socket.setsockopt_string(zmq.SUBSCRIBE, six.u(ADMIN_CHANNEL))

        def listen_socket():
            # wrap sub socket into ZeroMQ stream and set its on_recv callback
            self.sub_stream = ZMQStream(subscribe_socket)
            self.sub_stream.on_recv(self.dispatch_published_message)

        tornado.ioloop.IOLoop.instance().add_callback(listen_socket)

        if self.zmq_pub_sub_proxy:
            logger.info("ZeroMQ XPUB: {0}, XSUB: {1}".format(
                self.zmq_xpub, self.zmq_xsub))
        else:
            logger.info("ZeroMQ PUB - {0}; subscribed to {1}".format(
                self.zmq_pub_port, self.zmq_sub_address))

    def publish(self, channel, message, method=None):
        """
        Publish message into channel of stream.

        Sent as a three-frame multipart message: channel, method, message.
        """
        method = method or DEFAULT_PUBLISH_METHOD
        to_publish = [utf8(channel), utf8(method), utf8(message)]
        self.pub_stream.send_multipart(to_publish)

    def get_subscription_key(self, project_id, namespace, channel):
        """
        Create subscription name to catch messages from specific
        project, namespace and channel.
        """
        return str(
            CHANNEL_NAME_SEPARATOR.join(
                [project_id, namespace, channel, CHANNEL_SUFFIX]))

    @coroutine
    def dispatch_published_message(self, multipart_message):
        """
        Got message, decide what is it and dispatch into right
        application handler.

        Frames are expected in the order produced by ``publish``:
        channel, method, message data.
        """
        channel = multipart_message[0]
        method = multipart_message[1]
        message_data = multipart_message[2]
        if six.PY3:
            # frames are bytes on Python 3, while CONTROL_CHANNEL,
            # ADMIN_CHANNEL and subscription keys are text; without decoding
            # none of the comparisons/lookups below could ever match
            channel = channel.decode()
            method = method.decode()
            message_data = message_data.decode()
        if channel == CONTROL_CHANNEL:
            yield self.handle_control_message(message_data)
        elif channel == ADMIN_CHANNEL:
            yield self.handle_admin_message(message_data)
        else:
            yield self.handle_channel_message(channel, method, message_data)

    @coroutine
    def handle_admin_message(self, message):
        """
        Send ``message`` to every admin connection of the application.
        """
        admin_connections = self.application.admin_connections
        # iterate a snapshot so entries removed while sending do not break
        # the loop; the membership check skips already-removed ones
        for uid, connection in list(six.iteritems(admin_connections)):
            if uid in admin_connections:
                connection.send(message)

    @coroutine
    def handle_channel_message(self, channel, method, message):
        """
        Send a channel message to every client subscribed to ``channel``
        in this application instance.
        """
        if channel not in self.subscriptions:
            raise Return((True, None))

        response = Response(method=method, body=message)
        prepared_response = response.as_message()

        # snapshot: subscriptions may change while clients are being sent to
        for uid, client in list(six.iteritems(self.subscriptions[channel])):
            if channel in self.subscriptions and uid in self.subscriptions[
                    channel]:
                client.send(prepared_response)

    @coroutine
    def handle_control_message(self, message):
        """
        Handle control message.

        Decodes the JSON message and calls ``application.handle_<method>``
        with its params, unless it was sent by this very application.
        """
        message = json_decode(message)

        app_id = message.get("app_id")
        method = message.get("method")
        params = message.get("params")

        if app_id and app_id == self.application.uid:
            # application id must be set when we don't want to do
            # make things twice for the same application. Setting
            # app_id means that we don't want to process control
            # message when it is appear in application instance if
            # application uid matches app_id
            raise Return((True, None))

        func = getattr(self.application, 'handle_%s' % method, None)
        if not func:
            raise Return((None, 'method not found'))

        result, error = yield func(params)
        raise Return((result, error))

    def add_subscription(self, project_id, namespace_name, channel, client):
        """
        Subscribe application on channel if necessary and register client
        to receive messages from that channel.
        """
        subscription_key = self.get_subscription_key(project_id,
                                                     namespace_name, channel)

        if subscription_key not in self.subscriptions:
            # ZeroMQ keeps one filter instance per SUBSCRIBE and UNSUBSCRIBE
            # removes only one, so subscribe once per key; the single
            # UNSUBSCRIBE in remove_subscription then really stops delivery
            self.sub_stream.setsockopt_string(zmq.SUBSCRIBE,
                                              six.u(subscription_key))
            self.subscriptions[subscription_key] = {}

        self.subscriptions[subscription_key][client.uid] = client

    def remove_subscription(self, project_id, namespace_name, channel, client):
        """
        Unsubscribe application from channel if necessary and unregister client
        from receiving messages from that channel.
        """
        subscription_key = self.get_subscription_key(project_id,
                                                     namespace_name, channel)

        try:
            del self.subscriptions[subscription_key][client.uid]
        except KeyError:
            pass

        try:
            if not self.subscriptions[subscription_key]:
                # last local client left: drop the socket-level filter too
                self.sub_stream.setsockopt_string(zmq.UNSUBSCRIBE,
                                                  six.u(subscription_key))
                del self.subscriptions[subscription_key]
        except KeyError:
            pass