Example #1
0
def run_command_from_args(args=None, **extra):
    """Parse command line arguments and run the selected command.

    If the parsed arguments include ``config_file`` (i.e. the command set up
    the common arguments), the config is loaded and plugins are auto-loaded
    from it. The handler is then called as
    ``func(parsed_args, config, plugin_registry, **extra)``. Otherwise the
    handler is called as ``func(parsed_args, **extra)``.

    Args:
        args: Argument strings handed to ``parse_args()``; ``None`` means the
            process's own command line.
        **extra: Extra keyword arguments forwarded to the command handler.

    A ``FailedToImportBusModule`` raised by the handler is written to stderr
    instead of being propagated.
    """
    parsed_args = parse_args(args)

    # Store the args we received away for later. The
    # bus creation process will use them as overrides.
    # This is a bit of hack to allow command line overrides
    # while having no control over bus instantiation (because
    # the application developer does that in bus.py).
    # vars() is the documented way to view a Namespace as a dict
    # (Namespace._get_kwargs() is a private argparse helper).
    COMMAND_PARSED_ARGS.clear()
    COMMAND_PARSED_ARGS.update(vars(parsed_args))

    if hasattr(parsed_args, "config_file"):
        config = load_config(parsed_args)
        plugin_registry = PluginRegistry()
        plugin_registry.autoload_plugins(config)
    else:
        # Command didn't set up any of the common args (via setup_common_arguments()),
        # which are needed to load up the config. As the command didn't set these up,
        # then assume the command doesn't want to be passed the config & plugin_registry
        config = None
        plugin_registry = None

    try:
        if config is None:
            # Don't pass the config & plugin_registry, see above comment
            parsed_args.func(parsed_args, **extra)
        else:
            parsed_args.func(parsed_args, config, plugin_registry, **extra)

    except FailedToImportBusModule as e:
        sys.stderr.write(f"{RED}{e}{RESET}\n")
Example #2
0
def parse_args(args=None):
    """Build the ``lightbus`` command line parser and parse *args*.

    Every built-in command registers its own arguments. Before parsing,
    installed plugins get the chance to adjust the parser through the
    ``before_parse_args`` hook.

    Args:
        args: List of argument strings, or ``None`` to parse ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: The parsed arguments.
    """
    parser = argparse.ArgumentParser(description="Lightbus management command.")

    subparsers = parser.add_subparsers(help="Commands", dest="subcommand")
    subparsers.required = True

    # Each command sets up its own arguments (order preserved)
    command_modules = (
        lightbus.commands.run,
        lightbus.commands.shell,
        lightbus.commands.dump_schema,
        lightbus.commands.dump_config_schema,
        lightbus.commands.inspect,
        lightbus.commands.version,
    )
    for command_module in command_modules:
        command_module.Command().setup(parser, subparsers)

    # A throwaway plugin registry, loaded with an empty config, exists only
    # so plugins can run the before_parse_args hook
    hook_registry = PluginRegistry()
    hook_registry.autoload_plugins(config=Config.load_dict({}))
    before_parse = hook_registry.execute_hook(
        "before_parse_args", parser=parser, subparsers=subparsers
    )
    block(before_parse, timeout=5)

    argv = sys.argv[1:] if args is None else args
    # There is deliberately no after_parse_args hook. Plugins see the parsed
    # args later via the receive_args hook, once they have been instantiated
    return parser.parse_args(argv)
Example #3
0
def parse_args(args=None):
    """Build the ``lightbus`` command line parser and parse *args*.

    Registers the global options (service name, process name, config file,
    log level) and one subparser per built-in command. Installed plugins can
    then modify the parser via the ``before_parse_args`` hook, and finally
    the arguments are parsed.

    Args:
        args: List of argument strings, or ``None`` to parse ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: The parsed arguments.
    """
    parser = argparse.ArgumentParser(
        description="Lightbus management command.")
    parser.add_argument(
        "--service-name",
        "-s",
        help="Name of service in which this process resides. YOU SHOULD "
        "LIKELY SET THIS IN PRODUCTION. Can also be set using the "
        "LIGHTBUS_SERVICE_NAME environment. Will default to a random string.",
    )
    parser.add_argument(
        "--process-name",
        "-p",
        help=
        "A unique name of this process within the service. Can also be set using the "
        "LIGHTBUS_PROCESS_NAME environment. Will default to a random string.",
    )
    parser.add_argument("--config",
                        dest="config_file",
                        help="Config file to load, JSON or YAML",
                        metavar="FILE")
    parser.add_argument(
        "--log-level",
        help="Set the log level. Overrides any value set in config. "
        "One of debug, info, warning, critical, exception.",
        metavar="LOG_LEVEL",
    )

    subparsers = parser.add_subparsers(help="Commands", dest="subcommand")
    subparsers.required = True

    # Allow each command to set up its own arguments. Each command is set up
    # exactly once: setting one up twice presumably adds a duplicate
    # subparser, which argparse rejects ("conflicting subparser") on
    # Python 3.11+ and silently overwrites on older versions.
    lightbus.commands.run.Command().setup(parser, subparsers)
    lightbus.commands.shell.Command().setup(parser, subparsers)
    lightbus.commands.dump_schema.Command().setup(parser, subparsers)
    lightbus.commands.dump_config_schema.Command().setup(parser, subparsers)
    lightbus.commands.inspect.Command().setup(parser, subparsers)

    # Create a temporary plugin registry in order to run the before_parse_args hook
    plugin_registry = PluginRegistry()
    plugin_registry.autoload_plugins(config=Config.load_dict({}))

    block(
        plugin_registry.execute_hook("before_parse_args",
                                     parser=parser,
                                     subparsers=subparsers),
        timeout=5,
    )
    args = parser.parse_args(sys.argv[1:] if args is None else args)
    # Note that we don't have an after_parse_args plugin hook. Instead we use the receive_args
    # hook which is called once we have instantiated our plugins

    return args
Example #4
0
def test_plugin_enabled(plugin_registry: PluginRegistry):
    """Plugins enabled in the config are loaded by autoload_plugins()"""
    enabled_plugins = {
        plugin_name: {"enabled": True}
        for plugin_name in ("internal_state", "internal_metrics")
    }
    config = Config.load_dict({"plugins": enabled_plugins})

    plugin_registry.autoload_plugins(config)

    assert plugin_registry._plugins
Example #5
0
async def test_execute_hook(mocker, plugin_registry: PluginRegistry):
    """Ensure calling execute_hook() calls the method on the plugin"""
    # Start from an empty registry holding exactly one plugin
    assert not plugin_registry._plugins
    lightbus_plugin = LightbusPlugin()
    plugin_registry.set_plugins([lightbus_plugin])

    async def noop_hook(*args, **kwargs):
        pass

    # The patched hook hands back a ready-made coroutine object
    # (presumably awaited by execute_hook())
    hook_mock = mocker.patch.object(
        lightbus_plugin, "before_worker_start", return_value=noop_hook()
    )

    await plugin_registry.execute_hook("before_worker_start", client=None, loop=None)

    assert hook_mock.called
Example #6
0
def run_command_from_args(args=None, **extra):
    """Parse command line arguments and run the selected command.

    The command handler is called as
    ``func(parsed_args, config, plugin_registry, **extra)``. The config is
    loaded from the parsed arguments, and the registry has plugins
    auto-loaded from that config.

    Args:
        args: Argument strings handed to ``parse_args()``; ``None`` means the
            process's own command line.
        **extra: Extra keyword arguments forwarded to the command handler.
    """
    parsed_args = parse_args(args)

    # Store the args we received away for later. The
    # bus creation process will use them as overrides.
    # This is a bit of hack to allow command line overrides
    # while having no control over bus instantiation (because
    # the application developer does that in bus.py).
    # vars() is the documented way to view a Namespace as a dict
    # (Namespace._get_kwargs() is a private argparse helper).
    COMMAND_PARSED_ARGS.clear()
    COMMAND_PARSED_ARGS.update(vars(parsed_args))

    config = load_config(parsed_args)
    plugin_registry = PluginRegistry()
    plugin_registry.autoload_plugins(config)

    parsed_args.func(parsed_args, config, plugin_registry, **extra)
Example #7
0
def test_autoload_plugins(plugin_registry: PluginRegistry):
    """autoload_plugins() loads the enabled plugins, state plugin first"""
    plugins_config = {
        "internal_state": {"enabled": True},
        "internal_metrics": {"enabled": True},
    }
    config = Config.load_dict({"plugins": plugins_config})

    assert not plugin_registry._plugins
    assert plugin_registry.autoload_plugins(config)

    loaded_types = [type(loaded) for loaded in plugin_registry._plugins]
    assert loaded_types == [StatePlugin, MetricsPlugin]
Example #8
0
    def __init__(self, config: "Config", transport_registry: TransportRegistry = None):
        """Initialise the client's state and start its worker.

        Args:
            config: The bus configuration.
            transport_registry: Optional transport registry, stored as given.
        """
        self.config = config
        self.transport_registry = transport_registry

        # Registries & schema
        self.api_registry = Registry()
        self.plugin_registry = PluginRegistry()
        self.schema = None

        # Things registered with the bus by the user
        self._listeners = {}  # event listeners
        self._consumers = []  # RPC consumers
        self._background_tasks = []  # Other background tasks added by user
        self._hook_callbacks = defaultdict(list)

        # Server lifecycle state
        self._server_shutdown_queue: janus.Queue = None
        self._shutdown_monitor_task = None
        self._server_tasks = []
        self.exit_code = 0
        self._closed = False

        # Create and start the worker, then run the worker-side initialisation
        self.worker = ClientWorker()
        self.worker.start(bus_client=self, after_shutdown=self._handle_worker_shutdown)
        self.__init_worker___()
Example #9
0
    def _handle(self, args, config, plugin_registry: PluginRegistry):
        """Run the bus server until it is shut down.

        The steps are:

        1. Set up logging.
        2. Import the bus.
        3. Optionally load a local schema from ``args.schema``, where ``-``
           means stdin.
        4. Install SIGINT/SIGTERM handlers that request a server shutdown.
        5. Run the ``receive_args`` plugin hook.
        6. Run the client forever. RPCs are not consumed when
           ``args.events_only`` is set.

        If the client ends with a non-zero exit code, the process exits with
        that code.
        """
        self.setup_logging(override=getattr(args, "log_level", None), config=config)

        _, bus = self.import_bus(args)

        # TODO: Move to lightbus.create()?
        if args.schema:
            # '-' means read from stdin, signalled to load_local() by a None source
            schema_source = None if args.schema == "-" else args.schema
            bus.schema.load_local(schema_source)

        restart_signals = (signal.SIGINT, signal.SIGTERM)

        async def signal_handler():
            # Only the first signal is handled here. With our handlers removed,
            # a repeated signal lets the process quit naturally
            for signum in restart_signals:
                asyncio.get_event_loop().remove_signal_handler(signum)

            logger.debug("Caught signal. Stopping main thread event loop")
            bus.client.shutdown_server(exit_code=0)

        def on_signal():
            asyncio.ensure_future(signal_handler())

        for signum in restart_signals:
            asyncio.get_event_loop().add_signal_handler(signum, on_signal)

        try:
            block(plugin_registry.execute_hook("receive_args", args=args), timeout=5)
            if not args.events_only:
                bus.client.run_forever()
            else:
                bus.client.run_forever(consume_rpcs=False)
        finally:
            # Always remove our signal handlers, however the server stopped
            for signum in restart_signals:
                asyncio.get_event_loop().remove_signal_handler(signum)

        exit_code = bus.client.exit_code
        if exit_code:
            sys.exit(exit_code)
Example #10
0
    def handle(self, args, config, plugin_registry: PluginRegistry, fake_it=False):
        """Open an interactive bpython shell with the bus available as ``bus``.

        Args:
            args: Parsed command line arguments.
            config: Loaded configuration (used for logging setup).
            plugin_registry: Registry used to fire the ``receive_args`` hook.
            fake_it: If true, do all the setup but skip starting the REPL
                (useful for testing).
        """
        command_utilities.setup_logging(args.log_level or "warning", config)

        try:
            # pylint: disable=unused-import,cyclic-import,import-outside-toplevel
            import bpython
            from bpython.curtsies import main as bpython_main
        except ImportError:  # pragma: no cover
            print(
                "Lightbus shell requires bpython. Run `pip install bpython` to install bpython."
            )
            sys.exit(1)
            return  # noqa

        logging.getLogger("lightbus").setLevel(logging.WARNING)

        _, bus = command_utilities.import_bus(args)
        block(bus.client.lazy_load_now())

        # Expose every class defined on the lightbus package, plus the bus itself
        shell_locals = {
            attr_name: attr_value
            for attr_name, attr_value in lightbus.__dict__.items()
            if isclass(attr_value)
        }
        shell_locals["bus"] = bus

        block(plugin_registry.execute_hook("receive_args", args=args), timeout=5)

        if fake_it:
            return

        bpython_main(
            args=["-i", "-q"],
            locals_=shell_locals,
            welcome_message="Welcome to the Lightbus shell. Use `bus` to access your bus.",
        )
Example #11
0
def create(
    config: Union[dict, RootConfig] = None,
    *,
    config_file: str = None,
    service_name: str = None,
    process_name: str = None,
    features: List[Union[Feature, str]] = ALL_FEATURES,
    client_class: Type[BusClient] = BusClient,
    node_class: Type[BusPath] = BusPath,
    plugins=None,
    flask: bool = False,
    **kwargs,
) -> BusPath:
    """
    Create a new bus instance which can be used to access the bus.

    Typically this will be used as follows:

        import lightbus

        bus = lightbus.create()

    This will be a `BusPath` instance. If you wish to access the lower
    level `BusClient` you can do so via `bus.client`.

    When running via the Lightbus CLI, any `config_file`, `service_name` or
    `process_name` values stored in `lightbus.commands.COMMAND_PARSED_ARGS`
    take precedence over the arguments of the same name passed here.

    Args:
        config (dict, RootConfig, Config): The config object or dictionary to load.
            If omitted, the config is loaded using `config_file`, `service_name` and
            `process_name`. A dict is loaded via `Config.load_dict()`, a `RootConfig` is
            wrapped in `Config`, and any other value (presumably a `Config`) is used as-is.
        config_file (str): The path to a config file to load (should end in .json or .yaml).
            Only used when `config` is not given.
        service_name (str): The name of this service - will be used when creating event consumer groups.
            Only used when `config` is not given.
        process_name (str): The unique name of this process - used when retrieving unprocessed events
            following a crash. Only used when `config` is not given.
        features (list): The features to enable, as `Feature` values or strings. Passed to `client_class`.
        client_class (Type[BusClient]): The class from which the bus client will be instantiated
        node_class (Type[BusPath]): The class from which the bus path will be instantiated
        plugins (list): A list of plugin instances to load. If omitted, installed plugins are auto-loaded
            based on the config.
        flask (bool): Are we using flask? If so we will make sure we don't start lightbus in the reloader process
        **kwargs (): Any additional instantiation arguments to be passed to `client_class`.
            A `transport_registry` keyword is removed from these and, if truthy, used instead of
            building a transport registry from the config.

    Returns:
        BusPath: The root path of the bus. Returns None instead when `flask` is true,
        the process was started via `flask run`, and `WERKZEUG_RUN_MAIN` is not "true"
        (i.e. this is Flask's reloader process).

    """
    if flask:
        in_flask_server = sys.argv[0].endswith("flask") and "run" in sys.argv
        if in_flask_server and os.environ.get("WERKZEUG_RUN_MAIN", "").lower() != "true":
            # Flask has a reloader process that shouldn't start a lightbus client
            return

    # Ensure an event loop exists, as creating InternalQueue
    # objects requires that we have one.
    get_event_loop()

    # If we are running via the Lightbus CLI then we may have
    # some command line arguments we need to apply.
    # pylint: disable=cyclic-import,import-outside-toplevel
    from lightbus.commands import COMMAND_PARSED_ARGS

    # Command line values (when present and truthy) win over the arguments given here
    config_file = COMMAND_PARSED_ARGS.get("config_file", None) or config_file
    service_name = COMMAND_PARSED_ARGS.get("service_name", None) or service_name
    process_name = COMMAND_PARSED_ARGS.get("process_name", None) or process_name

    if config is None:
        config = load_config(
            from_file=config_file, service_name=service_name, process_name=process_name
        )

    # Normalise the config into a Config instance
    if isinstance(config, Mapping):
        config = Config.load_dict(config or {})
    elif isinstance(config, RootConfig):
        config = Config(config)

    # An explicitly supplied transport registry is popped so it is not also passed to client_class
    transport_registry = kwargs.pop("transport_registry", None) or TransportRegistry().load_config(
        config
    )

    schema = Schema(
        schema_transport=transport_registry.get_schema_transport(),
        max_age_seconds=config.bus().schema.ttl,
        human_readable=config.bus().schema.human_readable,
    )

    # Shared by the clients and docks created below
    error_queue: ErrorQueueType = InternalQueue()

    # Plugin registry

    plugin_registry = PluginRegistry()
    if plugins is None:
        logger.debug("Auto-loading any installed Lightbus plugins...")
        plugin_registry.autoload_plugins(config)
    else:
        logger.debug("Loading explicitly specified Lightbus plugins....")
        plugin_registry.set_plugins(plugins)

    # Hook registry

    hook_registry = HookRegistry(
        error_queue=error_queue, execute_plugin_hooks=plugin_registry.execute_hook
    )

    # API registry

    api_registry = ApiRegistry()
    api_registry.add(LightbusStateApi())
    api_registry.add(LightbusMetricsApi())

    # Each client/dock pair communicates over two queues, one per direction:
    # the client produces to the queue the dock consumes from, and vice versa
    events_queue_client_to_dock = InternalQueue()
    events_queue_dock_to_client = InternalQueue()

    event_client = EventClient(
        api_registry=api_registry,
        hook_registry=hook_registry,
        config=config,
        schema=schema,
        error_queue=error_queue,
        consume_from=events_queue_dock_to_client,
        produce_to=events_queue_client_to_dock,
    )

    event_dock = EventDock(
        transport_registry=transport_registry,
        api_registry=api_registry,
        config=config,
        error_queue=error_queue,
        consume_from=events_queue_client_to_dock,
        produce_to=events_queue_dock_to_client,
    )

    rpcs_queue_client_to_dock = InternalQueue()
    rpcs_queue_dock_to_client = InternalQueue()

    rpc_result_client = RpcResultClient(
        api_registry=api_registry,
        hook_registry=hook_registry,
        config=config,
        schema=schema,
        error_queue=error_queue,
        consume_from=rpcs_queue_dock_to_client,
        produce_to=rpcs_queue_client_to_dock,
    )

    rpc_result_dock = RpcResultDock(
        transport_registry=transport_registry,
        api_registry=api_registry,
        config=config,
        error_queue=error_queue,
        consume_from=rpcs_queue_client_to_dock,
        produce_to=rpcs_queue_dock_to_client,
    )

    client = client_class(
        config=config,
        hook_registry=hook_registry,
        plugin_registry=plugin_registry,
        features=features,
        schema=schema,
        api_registry=api_registry,
        event_client=event_client,
        rpc_result_client=rpc_result_client,
        error_queue=error_queue,
        transport_registry=transport_registry,
        **kwargs,
    )

    # Pass the client to any hooks
    # (use a weakref to prevent circular references)
    hook_registry.set_extra_parameter("client", weakref.proxy(client))

    # We don't do this normally as the docks do not need to be
    # accessed directly, but this is useful in testing
    # TODO: Testing flag removed, but these are only needed in testing.
    #       Perhaps wrap them up in a way that makes this obvious
    client.event_dock = event_dock
    client.rpc_result_dock = rpc_result_dock

    log_welcome_message(
        logger=logger,
        transport_registry=transport_registry,
        schema=schema,
        plugin_registry=plugin_registry,
        config=config,
    )

    return node_class(name="", parent=None, client=client)
Example #12
0
 def handle(self, args, config, plugin_registry: PluginRegistry):
     """Run the command, letting plugins see any exception it raises.

     An exception from ``_handle()`` is passed to the ``exception`` plugin
     hook and then re-raised unchanged.
     """
     try:
         self._handle(args, config, plugin_registry)
     except Exception as error:
         exception_hook = plugin_registry.execute_hook("exception", e=error)
         block(exception_hook, timeout=5)
         raise
Example #13
0
class BusClient(object):
    """Provides a the lower level interface for accessing the bus

    The low-level `BusClient` is less expressive than the interface provided by `BusPath`,
    but does allow for more control in some situations.

    All functionality in `BusPath` is provided by `BusClient`.
    """

    def __init__(self, config: "Config", transport_registry: TransportRegistry = None):
        """Create the client's state and start its worker thread.

        Args:
            config: The bus configuration.
            transport_registry: Optional transport registry. If falsy, one is
                built from ``config`` by ``__init_worker___()``.
        """
        self.config = config
        self.transport_registry = transport_registry

        # Registries & schema (the schema is created in __init_worker___())
        self.api_registry = Registry()
        self.plugin_registry = PluginRegistry()
        self.schema = None

        # Things registered with the bus by the user
        self._listeners = {}  # event listeners
        self._consumers = []  # RPC consumers
        self._background_tasks = []  # Other background tasks added by user
        self._hook_callbacks = defaultdict(list)

        # Server lifecycle state
        self._server_shutdown_queue: janus.Queue = None
        self._shutdown_monitor_task = None
        self._server_tasks = []
        self.exit_code = 0
        self._closed = False

        # Create and start the worker, then finish initialising within it
        self.worker = ClientWorker()
        self.worker.start(bus_client=self, after_shutdown=self._handle_worker_shutdown)
        self.__init_worker___()

    @run_in_worker_thread()
    def __init_worker___(self):
        """Finish initialisation within the worker thread.

        Builds a transport registry from the config when none was supplied,
        then creates the schema and wraps it in a ``WorkerProxy`` bound to
        this client's worker.
        """
        if not self.transport_registry:
            self.transport_registry = TransportRegistry().load_config(self.config)

        schema_transport = self.transport_registry.get_schema_transport()
        schema = Schema(
            schema_transport=schema_transport,
            max_age_seconds=self.config.bus().schema.ttl,
            human_readable=self.config.bus().schema.human_readable,
        )
        self.schema = WorkerProxy(proxied=schema, worker=self.worker)

    def _handle_worker_shutdown(self):
        """Close the bus after the worker thread's event loop has stopped.

        This runs inside the worker thread, so ``_close_async_inner()`` is
        called directly rather than via ``close()``. An already-closed bus
        (the clean shutdown case) is not treated as an error.
        """
        try:
            block(self._close_async_inner())
        except BusAlreadyClosed:
            # Clean shutdown: the bus was closed before the worker stopped
            pass

    @run_in_worker_thread()
    async def setup_async(self, plugins: list = None):
        """Setup lightbus and get it ready to consume events and/or RPCs

        Logs the service/process names and default transports, loads plugins,
        loads the schema from the bus, then opens every transport.

        You should call this manually if you are calling `consume_rpcs()`
        directly. NOTE(review): `run_forever()` in this class does not appear
        to call this method — confirm where setup happens in that case.

        Args:
            plugins: A list of plugin instances to use. When ``None``, any
                installed plugins are auto-loaded based on the config.
        """
        logger.info(
            LBullets(
                "Lightbus is setting up",
                items={
                    "service_name (set with -s or LIGHTBUS_SERVICE_NAME)": Bold(
                        self.config.service_name
                    ),
                    "process_name (set with -p or LIGHTBUS_PROCESS_NAME)": Bold(
                        self.config.process_name
                    ),
                },
            )
        )

        # Log the transport information
        rpc_transport = self.transport_registry.get_rpc_transport("default", default=None)
        result_transport = self.transport_registry.get_result_transport("default", default=None)
        event_transport = self.transport_registry.get_event_transport("default", default=None)
        log_transport_information(
            rpc_transport, result_transport, event_transport, self.schema.schema_transport, logger
        )

        # Load the plugins, then log which ones we have
        if plugins is None:
            logger.debug("Auto-loading any installed Lightbus plugins...")
            self.plugin_registry.autoload_plugins(self.config)
        else:
            logger.debug("Loading explicitly specified Lightbus plugins....")
            self.plugin_registry.set_plugins(plugins)

        if self.plugin_registry._plugins:
            logger.info(
                LBullets(
                    "Loaded the following plugins ({})".format(len(self.plugin_registry._plugins)),
                    items=self.plugin_registry._plugins,
                )
            )
        else:
            logger.info("No plugins loaded")

        # Load schema
        logger.debug("Loading schema...")
        await self.schema.load_from_bus()

        logger.info(
            LBullets(
                "Loaded the following remote schemas ({})".format(len(self.schema.remote_schemas)),
                items=self.schema.remote_schemas.keys(),
            )
        )

        for transport in self.transport_registry.get_all_transports():
            await transport.open()

    def setup(self, plugins: dict = None):
        """Blocking wrapper around ``setup_async()``, run via ``block()`` with a 5 second timeout."""
        setup_coroutine = self.setup_async(plugins)
        block(setup_coroutine, timeout=5)

    @assert_not_in_worker_thread()
    def close(self, _stop_worker=True):
        """Close the bus client

        This will cancel all tasks and close all transports/connections.

        Raises:
            BusAlreadyClosed: If the client has already been closed.
        """
        already_closed = self._closed
        if already_closed:
            raise BusAlreadyClosed()
        closing = self.close_async(_stop_worker=_stop_worker)
        block(closing)

    @assert_not_in_worker_thread()
    async def close_async(self, _stop_worker=True):
        """Async version of close()

        When ``_stop_worker`` is true, the worker is shut down and waited for
        even if closing fails (including with ``BusAlreadyClosed``).
        """
        try:
            if not self._closed:
                await self._close_async_inner()
            else:
                raise BusAlreadyClosed()
        finally:
            # Stop the worker regardless of the outcome, otherwise the bus
            # thread keeps running and prevents the process from exiting
            if _stop_worker:
                self.worker.shutdown()
                await self.worker.wait_for_shutdown()

    @run_in_worker_thread()
    async def _close_async_inner(self):
        """Close everything that must be closed from inside the bus worker thread.

        Cancels listener tasks and user background tasks (errors from
        cancelling are logged, not raised), closes every transport and the
        schema transport, and finally marks the client as closed.

        Raises:
            BusAlreadyClosed: If the client was already closed.
        """
        if self._closed:
            raise BusAlreadyClosed()

        listeners = [t for t in all_tasks() if getattr(t, "is_listener", False)]

        for pending in chain(listeners, self._background_tasks):
            try:
                await cancel(pending)
            except Exception as e:
                logger.exception(e)

        for open_transport in self.transport_registry.get_all_transports():
            await open_transport.close()
        await self.schema.schema_transport.close()

        self._closed = True

    @property
    def loop(self):
        """The event loop, as returned by ``get_event_loop()``."""
        current_loop = get_event_loop()
        return current_loop

    def run_forever(self, *, consume_rpcs=True):
        """Run the server until shut down, then stop the server and close the bus.

        The main thread's event loop runs until it is stopped (for example by
        ``shutdown_server()``).

        Args:
            consume_rpcs: Forwarded to ``start_server()``, which decides whether
                RPCs are consumed.
        """
        # Forward consume_rpcs; it was previously accepted but never used
        self.start_server(consume_rpcs=consume_rpcs)

        self._actually_run_forever()
        logger.debug("Main thread event loop was stopped")

        # Stopping the server requires access to the worker,
        # so do this first
        logger.debug("Stopping server")
        self.stop_server()

        # Here we close connections and shutdown the worker thread
        logger.debug("Closing bus")
        self.close()

    def shutdown_server(self, exit_code):
        """Ask the running server to stop, recording ``exit_code``.

        The code is put on the queue created by ``start_server()``, so this
        only works once the server has been started.
        """
        shutdown_queue = self._server_shutdown_queue
        shutdown_queue.sync_q.put(exit_code)

    @assert_not_in_worker_thread()
    def start_server(self, consume_rpcs=True):
        """Server startup procedure

        Must be called from within the main thread. Creates the shutdown
        queue and a monitor task that records the exit code and stops the
        event loop once ``shutdown_server()`` is called. The rest of start-up
        is then performed in the worker thread.

        Args:
            consume_rpcs: Whether RPCs should be consumed. Forwarded to
                ``_start_server_inner()``.
        """
        # Ensure an event loop exists
        get_event_loop()

        self._server_shutdown_queue = janus.Queue()
        self._server_tasks = set()

        async def server_shutdown_monitor():
            exit_code = await self._server_shutdown_queue.async_q.get()
            self.exit_code = exit_code
            self.loop.stop()
            self._server_shutdown_queue.async_q.task_done()

        shutdown_monitor_task = asyncio.ensure_future(server_shutdown_monitor())
        shutdown_monitor_task.add_done_callback(make_exception_checker(self, die=True))
        self._shutdown_monitor_task = shutdown_monitor_task

        # Forward consume_rpcs; previously it was dropped here, so RPCs were always consumed
        block(self._start_server_inner(consume_rpcs=consume_rpcs))

    @run_in_worker_thread()
    async def _start_server_inner(self, consume_rpcs=True):
        """Worker-thread part of ``start_server()``.

        Registers the internal state & metrics APIs. If ``consume_rpcs`` is
        set, it logs the registered APIs and, when there are any, starts
        consuming RPCs. It then starts schema monitoring, runs the
        ``before_server_start`` hook, and stores the created tasks in
        ``self._server_tasks`` (the RPC task slot is ``None`` when RPCs are
        not consumed).
        """
        self.api_registry.add(LightbusStateApi())
        self.api_registry.add(LightbusMetricsApi())

        consume_rpc_task = None
        if consume_rpcs:
            logger.info(
                LBullets(
                    "APIs in registry ({})".format(len(self.api_registry.all())),
                    items=self.api_registry.names(),
                )
            )
            if self.api_registry.all():
                consume_rpc_task = asyncio.ensure_future(self.consume_rpcs())
                consume_rpc_task.add_done_callback(make_exception_checker(self, die=True))

        # Watch the schema for changes
        monitor_task = asyncio.ensure_future(self.schema.monitor())
        monitor_task.add_done_callback(make_exception_checker(self, die=True))

        logger.info("Executing before_server_start & on_start hooks...")
        await self._execute_hook("before_server_start")
        logger.info("Execution of before_server_start & on_start hooks was successful")

        self._server_tasks = [consume_rpc_task, monitor_task]

    @assert_not_in_worker_thread()
    def stop_server(self):
        """Stop the server started by ``start_server()``.

        Cancels the shutdown monitor task, then cancels the remaining server
        tasks and runs the stop hooks within the worker thread.
        """
        monitor_cancellation = cancel(self._shutdown_monitor_task)
        block(monitor_cancellation)
        block(self._stop_server_inner())

    @run_in_worker_thread()
    async def _stop_server_inner(self):
        """Worker-thread part of ``stop_server()``: cancel server tasks, then run stop hooks."""
        server_tasks = self._server_tasks
        await cancel(*server_tasks)

        logger.info("Executing after_server_stopped & on_stop hooks...")
        await self._execute_hook("after_server_stopped")
        logger.info("Execution of after_server_stopped & on_stop hooks was successful")

    def _actually_run_forever(self):  # pragma: no cover
        """Run the event loop until something stops it.

        Exists as its own method only so that tests can mock it out.
        """
        event_loop = self.loop
        event_loop.run_forever()

    # RPCs

    @run_in_worker_thread()
    async def consume_rpcs(self, apis: List[Api] = None):
        """Start consuming RPCs for *apis* (default: every registered API).

        APIs are grouped by the RPC transport serving them. One consumer
        coroutine runs per group, all gathered into a single task that is
        appended to ``self._consumers``.

        Raises:
            NoApisToListenOn: If there are no APIs to consume on.
        """
        if apis is None:
            apis = self.api_registry.all()

        if not apis:
            raise NoApisToListenOn(
                "No APIs to consume on in consume_rpcs(). Either this method was called with apis=[], "
                "or the API registry is empty."
            )

        # APIs may be served by different transports, so group them by transport
        names_by_transport = self.transport_registry.get_rpc_transports(
            [api.meta.name for api in apis]
        )

        consumers = [
            self._consume_rpcs_with_transport(
                rpc_transport=rpc_transport,
                apis=[self.api_registry.get(api_name) for api_name in transport_api_names],
            )
            for rpc_transport, transport_api_names in names_by_transport
        ]

        gathered = asyncio.ensure_future(asyncio.gather(*consumers))
        gathered.add_done_callback(make_exception_checker(self, die=True))
        self._consumers.append(gathered)

    async def _consume_rpcs_with_transport(
        self, rpc_transport: RpcTransport, apis: List[Api] = None
    ):
        """Consume and execute RPCs from a single transport until it closes.

        For each incoming RPC message this:

        1. Validates the message and runs the ``before_rpc_execution`` hook.
        2. Executes the procedure locally.
        3. Wraps the return value (or the exception raised) in a ``ResultMessage``.
        4. Runs the ``after_rpc_execution`` hook.
        5. Validates the result and sends it.

        Returns when the transport raises ``TransportIsClosed``, or when the
        procedure raises ``SuddenDeathException`` (used in testing).
        ``CancelledError`` is always re-raised.
        """
        while True:
            try:
                rpc_messages = await rpc_transport.consume_rpcs(apis, bus_client=self)
            except TransportIsClosed:
                # The transport is closed, so there is nothing more to consume
                return

            for rpc_message in rpc_messages:
                self._validate(rpc_message, "incoming")

                await self._execute_hook("before_rpc_execution", rpc_message=rpc_message)
                try:
                    result = await self.call_rpc_local(
                        api_name=rpc_message.api_name,
                        name=rpc_message.procedure_name,
                        kwargs=rpc_message.kwargs,
                    )
                except SuddenDeathException:
                    # Used to simulate message failure for testing
                    return
                except CancelledError:
                    # Never swallow cancellation
                    raise
                except Exception as e:
                    # Any other error from the procedure becomes the result sent back
                    result = e
                else:
                    result = deform_to_bus(result)

                result_message = ResultMessage(result=result, rpc_message_id=rpc_message.id)
                await self._execute_hook(
                    "after_rpc_execution", rpc_message=rpc_message, result_message=result_message
                )

                self._validate(
                    result_message,
                    "outgoing",
                    api_name=rpc_message.api_name,
                    procedure_name=rpc_message.procedure_name,
                )

                await self.send_result(rpc_message=rpc_message, result_message=result_message)

    @run_in_worker_thread()
    async def call_rpc_remote(
        self, api_name: str, name: str, kwargs: dict = frozendict(), options: dict = frozendict()
    ):
        """Call a remote RPC and wait for its result.

        The RPC message is sent via the API's RPC transport while the result
        is awaited on the API's result transport, both within one gathered
        future.

        Args:
            api_name: Name of the API the procedure belongs to.
            name: Name of the procedure to call.
            kwargs: Keyword arguments for the procedure (passed through
                ``deform_to_bus()`` before sending).
            options: Call options, also passed on to the transports. A
                ``timeout`` entry overrides the API's configured ``rpc_timeout``.

        Returns:
            The ``result`` of the received ``ResultMessage``.

        Raises:
            LightbusTimeout: If the call does not complete within the timeout.
            LightbusServerError: If the result message reports an error.
        """
        rpc_transport = self.transport_registry.get_rpc_transport(api_name)
        result_transport = self.transport_registry.get_result_transport(api_name)

        kwargs = deform_to_bus(kwargs)
        rpc_message = RpcMessage(api_name=api_name, procedure_name=name, kwargs=kwargs)
        # The return path tells the remote side where to send the result
        return_path = result_transport.get_return_path(rpc_message)
        rpc_message.return_path = return_path
        options = options or {}
        timeout = options.get("timeout", self.config.api(api_name).rpc_timeout)
        # TODO: rpc_timeout is in three different places in the config!
        #       Fix this. Really it makes most sense for the user if it goes on the
        #       ApiConfig rather than having to repeat it on both the result & RPC
        #       transports.
        self._validate_name(api_name, "rpc", name)

        logger.info("📞  Calling remote RPC {}.{}".format(Bold(api_name), Bold(name)))

        start_time = time.time()
        # TODO: It is possible that the RPC will be called before we start waiting for the response. This is bad.

        self._validate(rpc_message, "outgoing")

        # Wait for the result and send the call concurrently
        future = asyncio.gather(
            self.receive_result(rpc_message, return_path, options=options),
            rpc_transport.call_rpc(rpc_message, options=options, bus_client=self),
        )

        await self._execute_hook("before_rpc_call", rpc_message=rpc_message)

        try:
            result_message, _ = await asyncio.wait_for(future, timeout=timeout)
            # NOTE(review): this result() call looks redundant, as wait_for() has
            # already returned the gathered result
            future.result()
        except asyncio.TimeoutError:
            # Allow the future to finish, as per https://bugs.python.org/issue29432
            try:
                await future
                future.result()
            except CancelledError:
                pass

            # TODO: Remove RPC from queue. Perhaps add a RpcBackend.cancel() method. Optional,
            #       as not all backends will support it. No point processing calls which have timed out.
            raise LightbusTimeout(
                f"Timeout when calling RPC {rpc_message.canonical_name} after {timeout} seconds. "
                f"It is possible no Lightbus process is serving this API, or perhaps it is taking "
                f"too long to process the request. In which case consider raising the 'rpc_timeout' "
                f"config option."
            ) from None

        await self._execute_hook(
            "after_rpc_call", rpc_message=rpc_message, result_message=result_message
        )

        if not result_message.error:
            logger.info(
                L(
                    "🏁  Remote call of {} completed in {}",
                    Bold(rpc_message.canonical_name),
                    human_time(time.time() - start_time),
                )
            )
        else:
            # The remote procedure raised; surface it locally with the remote trace
            logger.warning(
                L(
                    "⚡ Server error during remote call of {}. Took {}: {}",
                    Bold(rpc_message.canonical_name),
                    human_time(time.time() - start_time),
                    result_message.result,
                )
            )
            raise LightbusServerError(
                "Error while calling {}: {}\nRemote stack trace:\n{}".format(
                    rpc_message.canonical_name, result_message.result, result_message.trace
                )
            )

        self._validate(result_message, "incoming", api_name, procedure_name=name)

        return result_message.result

    @run_in_worker_thread()
    async def call_rpc_local(self, api_name: str, name: str, kwargs: dict = frozendict()):
        """Execute the RPC ``name`` on the locally registered API ``api_name``.

        The method is looked up on the API registered under ``api_name``. If the
        API's ``cast_values`` config option is enabled, ``kwargs`` are first cast
        to the method's signature. The method is then run via
        ``run_user_provided_callable()`` and its result returned.

        Raises:
            InvalidName: ``name`` is empty or starts with an underscore.
            CancelledError / SuddenDeathException: re-raised without being
                logged as an error.
            Exception: anything else raised by the RPC is logged (with its
                traceback) and re-raised.
        """
        api = self.api_registry.get(api_name)
        self._validate_name(api_name, "rpc", name)

        start_time = time.time()
        try:
            method = getattr(api, name)
            if self.config.api(api_name).cast_values:
                kwargs = cast_to_signature(kwargs, method)
            result = await run_user_provided_callable(
                method, args=[], kwargs=kwargs, bus_client=self
            )
        except (CancelledError, SuddenDeathException):
            # Cancellation / sudden death are not RPC failures; propagate untouched
            raise
        except Exception as e:
            # Log via the module logger (like every other log call here) rather
            # than the root logger, so the traceback honours this module's
            # logging configuration.
            logger.exception(e)
            logger.warning(
                L(
                    "⚡  Error while executing {}.{}. Took {}",
                    Bold(api_name),
                    Bold(name),
                    human_time(time.time() - start_time),
                )
            )
            raise
        else:
            logger.info(
                L(
                    "⚡  Executed {}.{} in {}",
                    Bold(api_name),
                    Bold(name),
                    human_time(time.time() - start_time),
                )
            )
            return result

    # Events

    @run_in_worker_thread()
    async def fire_event(self, api_name, name, kwargs: dict = None, options: dict = None):
        """Fire the event ``api_name.name`` on the bus.

        The API must be present in the local API registry, the event must be
        defined on it, and ``kwargs`` must supply exactly the event's
        parameters. The outgoing message is validated, the
        ``before_event_sent`` / ``after_event_sent`` hooks are run around the
        send, and the event is handed to the API's event transport.

        Raises:
            UnknownApi: ``api_name`` is not in the API registry.
            InvalidName: ``name`` is empty or starts with an underscore.
            EventNotFound: the API defines no event called ``name``.
            InvalidEventArguments: ``kwargs`` keys differ from the event's
                parameter names.
        """
        kwargs = kwargs or {}
        try:
            api = self.api_registry.get(api_name)
        except UnknownApi:
            raise UnknownApi(
                "Lightbus tried to fire the event {api_name}.{name}, but no API named {api_name} was found in the "
                "registry. An API being in the registry implies you are an authority on that API. Therefore, "
                "Lightbus requires the API to be in the registry as it is a bad idea to fire "
                "events on behalf of remote APIs. However, this could also be caused by a typo in the "
                "API name or event name, or be because the API class has not been "
                "registered using bus.client.register_api(). ".format(**locals())
            )

        self._validate_name(api_name, "event", name)

        try:
            event = api.get_event(name)
        except EventNotFound:
            raise EventNotFound(
                "Lightbus tried to fire the event {api_name}.{name}, but the API {api_name} does not "
                "seem to contain an event named {name}. You may need to define the event, you "
                "may also be using the incorrect API. Also check for typos.".format(**locals())
            )

        if set(kwargs.keys()) != _parameter_names(event.parameters):
            raise InvalidEventArguments(
                "Invalid event arguments supplied when firing event. Attempted to fire event with "
                "{} arguments: {}. Event expected {}: {}".format(
                    len(kwargs),
                    sorted(kwargs.keys()),
                    len(event.parameters),
                    sorted(_parameter_names(event.parameters)),
                )
            )

        kwargs = deform_to_bus(kwargs)
        event_message = EventMessage(
            api_name=api.meta.name, event_name=name, kwargs=kwargs, version=api.meta.version
        )

        self._validate(event_message, "outgoing")

        event_transport = self.transport_registry.get_event_transport(api_name)
        await self._execute_hook("before_event_sent", event_message=event_message)
        # Pass the arguments to L() rather than pre-formatting the string, as
        # every other log call in this client does.
        logger.info(L("📤  Sending event {}.{}", Bold(api_name), Bold(name)))
        await event_transport.send_event(event_message, options=options, bus_client=self)
        await self._execute_hook("after_event_sent", event_message=event_message)

    async def listen_for_event(
        self, api_name, name, listener, listener_name: str, options: dict = None
    ):
        """Listen for a single event.

        Convenience wrapper which forwards to ``listen_for_events()`` with a
        one-element list containing ``(api_name, name)``, returning its result.
        """
        single_event = (api_name, name)
        return await self.listen_for_events(
            [single_event], listener, listener_name=listener_name, options=options
        )

    @run_in_worker_thread()
    async def listen_for_events(
        self, events: List[Tuple[str, str]], listener, listener_name: str, options: dict = None
    ):
        """Start listening for each ``(api_name, event_name)`` pair in ``events``.

        The listener is sanity-checked and every event name validated before an
        ``_EventListener`` is built for the events and its task started.
        """
        self._sanity_check_listener(listener)

        for event_api_name, event_name in events:
            self._validate_name(event_api_name, "event", event_name)

        _EventListener(
            events=events,
            listener_callable=listener,
            listener_name=listener_name,
            options=options,
            bus_client=self,
        ).make_task()

    # Results

    @run_in_worker_thread()
    async def send_result(self, rpc_message: RpcMessage, result_message: ResultMessage):
        """Send ``result_message`` back to the caller of ``rpc_message``.

        Uses the result transport configured for the RPC's API, and the return
        path carried on the RPC message.
        """
        transport = self.transport_registry.get_result_transport(rpc_message.api_name)
        return_path = rpc_message.return_path
        return await transport.send_result(
            rpc_message, result_message, return_path, bus_client=self
        )

    @run_in_worker_thread()
    async def receive_result(self, rpc_message: RpcMessage, return_path: str, options: dict):
        """Wait for the result of ``rpc_message`` on ``return_path`` and return it.

        Delegates to the result transport configured for the RPC's API.
        """
        api_result_transport = self.transport_registry.get_result_transport(
            rpc_message.api_name
        )
        result = await api_result_transport.receive_result(
            rpc_message, return_path, options, bus_client=self
        )
        return result

    @contextlib.contextmanager
    def _register_listener(self, events: List[Tuple[str, str]]):
        """A context manager to help keep track of what the bus is listening for

        On entry the count for each ``(api_name, event_name)`` pair in ``events``
        is incremented in ``self._listeners``. On exit (normal or via an
        exception/cancellation) each count is decremented, and pairs whose count
        drops to zero are removed.
        """
        # Materialise once: ``events`` is iterated for logging, registration and
        # deregistration, which would silently misbehave for a one-shot iterable.
        events = list(events)
        logger.info(
            LBullets("Registering listener for", items=[Bold(f"{a}.{e}") for a, e in events])
        )

        for api_name, event_name in events:
            key = (api_name, event_name)
            self._listeners.setdefault(key, 0)
            self._listeners[key] += 1

        try:
            yield
        finally:
            # Always undo the registration. Without the finally, an exception or
            # cancellation in the body would leave the counts permanently raised
            # and ``listeners`` would keep reporting these events.
            for api_name, event_name in events:
                key = (api_name, event_name)
                self._listeners[key] -= 1
                if not self._listeners[key]:
                    self._listeners.pop(key)

    @property
    def listeners(self) -> List[Tuple[str, str]]:
        """The events currently being listened to.

        A new list holding every key of the internal listener registry, each an
        ``(api_name, event_name)`` tuple.
        """
        return [key for key in self._listeners]

    def _validate(self, message: Message, direction: str, api_name=None, procedure_name=None):
        """Validate ``message`` against the bus schema when configured to do so.

        Args:
            message: The RPC, event or result message to validate.
            direction: ``"incoming"`` or ``"outgoing"`` (anything else raises
                ``AssertionError``). Nothing is validated unless the API's
                ``validate`` config enables this direction.
            api_name: Fallback API name, used when ``message`` has no
                ``api_name`` attribute (result messages carry neither name).
            procedure_name: Fallback RPC/event name, used when ``message`` has
                no truthy ``procedure_name`` and no ``event_name`` attribute.

        Raises:
            UnknownApi: no schema is loaded for the API and the API's
                ``strict_validation`` option is on. With it off, a warning is
                logged and validation is skipped.
        """
        valid_directions = ("incoming", "outgoing")
        if direction not in valid_directions:
            raise AssertionError("Invalid direction specified")

        api_name = getattr(message, "api_name", api_name)
        member_name = getattr(message, "procedure_name", None)
        if not member_name:
            member_name = getattr(message, "event_name", procedure_name)

        api_config = self.config.api(api_name)
        strict = api_config.strict_validation
        if not getattr(api_config.validate, direction):
            return

        schema_missing = api_name not in self.schema
        if schema_missing and not strict:
            logger.warning(
                f"Validation is enabled for API named '{api_name}', but there is no schema present for this API. "
                f"Validation is therefore not possible. You can force this to be an error by enabling "
                f"the 'strict_validation' config option. You can silence this message by disabling validation "
                f"for this API using the 'validate' option."
            )
            return
        if schema_missing:
            raise UnknownApi(
                f"Validation is enabled for API named '{api_name}', but there is no schema present for this API. "
                f"Validation is therefore not possible. You are also seeing this error because the "
                f"'strict_validation' setting is enabled. Disabling this setting will turn this exception "
                f"into a warning. "
            )

        if isinstance(message, (RpcMessage, EventMessage)):
            self.schema.validate_parameters(api_name, member_name, message.kwargs)
        elif isinstance(message, ResultMessage):
            self.schema.validate_response(api_name, member_name, message.result)

    @run_in_worker_thread()
    def add_background_task(
        self, coroutine: Union[Coroutine, asyncio.Future], cancel_on_close=True
    ):
        """Schedule ``coroutine`` to run in the background.

        The coroutine (or future) is wrapped with ``asyncio.ensure_future()``,
        and an exception checker built with ``die=True`` is attached as a done
        callback. The Lightbus process will therefore exit if the coroutine
        raises; see lightbus.utilities.async_tools.check_for_exception().

        When ``cancel_on_close`` is true, the task is remembered so that it can
        be cancelled when the bus client is closed.
        """
        background_task = asyncio.ensure_future(coroutine)
        exception_checker = make_exception_checker(self, die=True)
        background_task.add_done_callback(exception_checker)
        if not cancel_on_close:
            return
        # Remember the task so it can be cancelled on close
        self._background_tasks.append(background_task)

    # Utilities

    def _validate_name(self, api_name: str, type_: str, name: str):
        """Validate that the given RPC/event name is ok to use"""
        if not name:
            raise InvalidName(f"Empty {type_} name specified when calling API {api_name}")

        if name.startswith("_"):
            raise InvalidName(
                f"You can not use '{api_name}.{name}' as an {type_} because it starts with an underscore. "
                f"API attributes starting with underscores are not available on the bus."
            )

    def _sanity_check_listener(self, listener):
        if not callable(listener):
            raise InvalidEventListener(
                f"The specified event listener {listener} is not callable. Perhaps you called the function rather "
                f"than passing the function itself?"
            )

        total_positional_args = 0
        has_variable_positional_args = False  # Eg: *args
        for parameter in inspect.signature(listener).parameters.values():
            if parameter.kind in (
                inspect.Parameter.POSITIONAL_ONLY,
                inspect.Parameter.POSITIONAL_OR_KEYWORD,
            ):
                total_positional_args += 1
            elif parameter.kind == inspect.Parameter.VAR_POSITIONAL:
                has_variable_positional_args = True

        if has_variable_positional_args:
            return

        if not total_positional_args:
            raise InvalidEventListener(
                f"The specified event listener {listener} must take at one positional argument. "
                f"This will be the event message. For example: "
                f"my_listener(event_message, other, ...)"
            )

    # Hooks

    async def _execute_hook(self, name, **kwargs):
        """Run every callback registered for hook ``name``.

        The order is: callbacks registered with ``before_plugins=True``, then
        the plugin registry's ``execute_hook``, then the remaining callbacks.
        Each callback is invoked with ``client=self`` plus ``kwargs``.
        """

        async def run_callbacks(before_plugins):
            for hook_callback in self._hook_callbacks[(name, before_plugins)]:
                await run_user_provided_callable(
                    hook_callback, args=[], kwargs=dict(client=self, **kwargs), bus_client=self
                )

        await run_callbacks(True)
        await self.plugin_registry.execute_hook(name, client=self, **kwargs)
        await run_callbacks(False)

    def _register_hook_callback(self, name, fn, before_plugins=False):
        self._hook_callbacks[(name, bool(before_plugins))].append(fn)

    def _make_hook_decorator(self, name, before_plugins=False, callback=None):
        if callback and not callable(callback):
            raise AssertionError("The provided callback is not callable")
        if callback:
            self._register_hook_callback(name, callback, before_plugins)
        else:

            def hook_decorator(fn):
                self._register_hook_callback(name, fn, before_plugins)
                return fn

            return hook_decorator

    def on_start(self, callback=None, *, before_plugins=False):
        """Alias of before_server_start(); takes and returns the same values."""
        register = self.before_server_start
        return register(callback, before_plugins=before_plugins)

    def on_stop(self, callback=None, *, before_plugins=False):
        """Alias for after_server_stopped"""
        # Must delegate to after_server_stopped(). Delegating to
        # before_server_start() would register stop callbacks as startup hooks.
        return self.after_server_stopped(callback, before_plugins=before_plugins)

    def before_server_start(self, callback=None, *, before_plugins=False):
        """Register a ``before_server_start`` hook, directly or as a decorator."""
        hook_name = "before_server_start"
        return self._make_hook_decorator(hook_name, before_plugins, callback)

    def after_server_stopped(self, callback=None, *, before_plugins=False):
        """Register an ``after_server_stopped`` hook, directly or as a decorator."""
        hook_name = "after_server_stopped"
        return self._make_hook_decorator(hook_name, before_plugins, callback)

    def before_rpc_call(self, callback=None, *, before_plugins=False):
        """Register a ``before_rpc_call`` hook, directly or as a decorator."""
        hook_name = "before_rpc_call"
        return self._make_hook_decorator(hook_name, before_plugins, callback)

    def after_rpc_call(self, callback=None, *, before_plugins=False):
        """Register an ``after_rpc_call`` hook, directly or as a decorator."""
        hook_name = "after_rpc_call"
        return self._make_hook_decorator(hook_name, before_plugins, callback)

    def before_rpc_execution(self, callback=None, *, before_plugins=False):
        """Register a ``before_rpc_execution`` hook, directly or as a decorator."""
        hook_name = "before_rpc_execution"
        return self._make_hook_decorator(hook_name, before_plugins, callback)

    def after_rpc_execution(self, callback=None, *, before_plugins=False):
        """Register an ``after_rpc_execution`` hook, directly or as a decorator."""
        hook_name = "after_rpc_execution"
        return self._make_hook_decorator(hook_name, before_plugins, callback)

    def before_event_sent(self, callback=None, *, before_plugins=False):
        """Register a ``before_event_sent`` hook, directly or as a decorator."""
        hook_name = "before_event_sent"
        return self._make_hook_decorator(hook_name, before_plugins, callback)

    def after_event_sent(self, callback=None, *, before_plugins=False):
        """Register an ``after_event_sent`` hook, directly or as a decorator."""
        hook_name = "after_event_sent"
        return self._make_hook_decorator(hook_name, before_plugins, callback)

    def before_event_execution(self, callback=None, *, before_plugins=False):
        """Register a ``before_event_execution`` hook, directly or as a decorator."""
        hook_name = "before_event_execution"
        return self._make_hook_decorator(hook_name, before_plugins, callback)

    def after_event_execution(self, callback=None, *, before_plugins=False):
        """Register an ``after_event_execution`` hook, directly or as a decorator."""
        hook_name = "after_event_execution"
        return self._make_hook_decorator(hook_name, before_plugins, callback)

    # Scheduling

    def every(
        self,
        *,
        seconds=0,
        minutes=0,
        hours=0,
        days=0,
        also_run_immediately=False,
        **timedelta_extra,
    ):
        """ Call a coroutine at the specified interval

        This is a simple scheduling mechanism which you can use in your bus module to setup
        recurring tasks. For example:

            bus = lightbus.create()

            @bus.client.every(seconds=30)
            def my_func():
                print("Hello")

        This can also be used to decorate async functions. In this case the function will be awaited.

        Note that the timing is best effort and is not guaranteed. That being said, execution
        time is accounted for.

        The time arguments (plus any extra ``timedelta`` keyword arguments) are
        combined into a single ``timedelta``, which must be positive.

        Raises:
            InvalidSchedule: if the combined interval is zero or negative.

        See Also:

            @bus.client.schedule()
        """
        td = timedelta(seconds=seconds, minutes=minutes, hours=hours, days=days, **timedelta_extra)

        # Reject negative intervals as well as zero. A negative timedelta has a
        # non-zero total_seconds(), so an `== 0` check would let it through.
        if td.total_seconds() <= 0:
            raise InvalidSchedule(
                "The @bus.client.every() decorator must be provided with a positive (non-zero) time argument. "
                "Ensure you are passing at least one time argument, and that the resulting interval is greater "
                "than zero."
            )

        # TODO: There is an argument that the backgrounding of this should be done only after
        #       on_start() has been fired. Otherwise this will be run before the on_start() setup
        #       has happened in cases where also_run_immediately=True.
        def wrapper(f):
            coroutine = call_every(  # pylint: disable=assignment-from-no-return
                callback=f, timedelta=td, also_run_immediately=also_run_immediately, bus_client=self
            )
            self.add_background_task(coroutine)
            return f

        return wrapper

    def schedule(self, schedule: "Job", also_run_immediately=False):
        """Decorator: call the decorated function according to ``schedule``.

        The decorated function is wrapped with ``call_on_schedule()`` and the
        resulting coroutine is added as a background task. The function itself
        is returned unchanged.

        See Also:

            @bus.client.every()
        """

        def decorator(func):
            self.add_background_task(
                call_on_schedule(
                    callback=func,
                    schedule=schedule,
                    also_run_immediately=also_run_immediately,
                    bus_client=self,
                )
            )
            return func

        return decorator

    # API registration

    def register_api(self, api: Api):
        """Register ``api`` with this client, blocking until it completes.

        Synchronous wrapper around register_api_async(), run through ``block()``
        with ``timeout=5``.
        """
        registration = self.register_api_async(api)
        block(registration, timeout=5)

    @run_in_worker_thread()
    async def register_api_async(self, api: Api):
        """Add ``api`` to the local API registry, then add it to the schema."""
        registry = self.api_registry
        registry.add(api)
        await self.schema.add_api(api)
Example #14
0
def plugin_registry():
    """Return a new ``PluginRegistry`` instance."""
    registry = PluginRegistry()
    return registry
Example #15
0
def test_is_plugin_loaded(plugin_registry: PluginRegistry):
    """is_plugin_loaded() is false until a plugin of that class is set."""
    # Truthiness rather than `== False` / `== True` (PEP 8, flake8 E712)
    assert not plugin_registry.is_plugin_loaded(LightbusPlugin)
    plugin_registry.set_plugins([LightbusPlugin()])
    assert plugin_registry.is_plugin_loaded(LightbusPlugin)
Example #16
0
def test_autoload_plugins(plugin_registry: PluginRegistry):
    """Autoloading with an empty config loads the State and Metrics plugins, in that order."""
    config = Config.load_dict({})
    assert not plugin_registry._plugins
    assert plugin_registry.autoload_plugins(config)
    loaded_types = [type(plugin) for plugin in plugin_registry._plugins]
    assert loaded_types == [StatePlugin, MetricsPlugin]
Example #17
0
def test_manually_set_plugins(plugin_registry: PluginRegistry):
    """set_plugins() stores exactly the given plugin instances, in order."""
    assert not plugin_registry._plugins
    first_plugin, second_plugin = LightbusPlugin(), LightbusPlugin()
    plugin_registry.set_plugins([first_plugin, second_plugin])
    assert plugin_registry._plugins == [first_plugin, second_plugin]
Example #18
0
    def _handle(self, args, config, plugin_registry: PluginRegistry):
        """Run the Lightbus server for the parsed command-line ``args``.

        The steps are:

        * Set up logging.
        * Import the bus module.
        * Choose the features to enable: from ``args.only`` / ``args.skip``,
          else the ``LIGHTBUS_FEATURES`` environment variable, else
          ``ALL_FEATURES``.
        * Optionally load a local schema (``"-"`` means stdin).
        * Install SIGINT/SIGTERM handlers that shut the server down.
        * Run the ``receive_args`` plugin hook, then run the bus until it
          stops.

        Calls ``sys.exit()`` with the bus client's exit code if that is truthy.
        """
        command_utilities.setup_logging(override=getattr(args, "log_level", None), config=config)

        bus_module, bus = command_utilities.import_bus(args)

        # Convert only & skip into a list of features to enable.
        # Always build a fresh list: removing skipped features must not mutate
        # args.only or self.all_features in place.
        env_features = os.environ.get("LIGHTBUS_FEATURES")
        if args.only or args.skip:
            features = list(args.only) if args.only else list(self.all_features)
            for skip_feature in args.skip or []:
                if skip_feature in features:
                    features.remove(skip_feature)
        elif env_features:
            features = csv_type(env_features)
        else:
            features = ALL_FEATURES

        bus.client.set_features(features)

        # TODO: Move to lightbus.create()?
        if args.schema:
            if args.schema == "-":
                # if '-' read from stdin
                source = None
            else:
                source = args.schema
            bus.schema.load_local(source)

        restart_signals = (signal.SIGINT, signal.SIGTERM)

        # Handle incoming signals
        async def signal_handler():
            # Stop handling signals now. If we receive the signal again
            # let the process quit naturally
            for signal_ in restart_signals:
                asyncio.get_event_loop().remove_signal_handler(signal_)

            logger.debug("Caught signal. Stopping main thread event loop")
            bus.client.shutdown_server(exit_code=0)

        for signal_ in restart_signals:
            asyncio.get_event_loop().add_signal_handler(
                signal_, lambda: asyncio.ensure_future(signal_handler())
            )

        try:
            block(plugin_registry.execute_hook("receive_args", args=args), timeout=5)
            bus.client.run_forever()

        finally:
            # Cleanup signal handlers
            for signal_ in restart_signals:
                asyncio.get_event_loop().remove_signal_handler(signal_)

        if bus.client.exit_code:
            sys.exit(bus.client.exit_code)