def test_get_api_fake_context(self):
    """get_api rejects a context value that is not a known ApiContext."""
    expected_message = "Invalid api context provided: 'Fake context'."
    with self.assertRaises(Exception) as raised:
        get_api("Fake context")
    self.assertEqual(expected_message, str(raised.exception))
def test_get_api_websockets(self, import_module_mock):
    """get_api(ApiContext.websocket, ...) builds an MsaApi backed by a websocket client."""
    fake_loop = MagicMock()
    white_list = ["a", "b", "c"]
    # One fake module per builtin module plus one per whitelisted plugin name,
    # handed out in order by the patched import function
    # (presumably builtins first, then plugins — TODO confirm against the loader).
    import_module_mock.side_effect = [
        FakeModule() for _ in range(len(builtin_module_names) + len(white_list))
    ]

    async def propagate(self, event_queue):
        # Consume one event; the value itself is not needed by this test.
        await event_queue.get()

    async def interact(self):
        pass

    api_instance = get_api(
        ApiContext.websocket,
        plugin_whitelist=white_list,
        loop=fake_loop,
        interact=interact,
        propagate=propagate,
        host="localhost",
        port="8080",
    )
    self.assertIsInstance(api_instance, patchable_api.MsaApi)
    self.assertEqual(api_instance.context, ApiContext.websocket)
    self.assertIsInstance(api_instance.client, api_clients.ApiWebsocketClient)
    # assertIn reports both operands on failure, unlike assertTrue(x in y).
    self.assertIn("ping", api_instance)
def test_get_api_local_no_plugin_whitelist(self):
    """get_api refuses keyword arguments (here ``loop``) when no plugin whitelist is given."""
    fake_loop = MagicMock()
    with self.assertRaises(Exception) as cm:
        get_api(ApiContext.local, loop=fake_loop)
    self.assertEqual(
        "get_api: plugin_whitelist cannot be none when **kwargs are provided to get_api.",
        str(cm.exception),
    )
def test_get_api_websocket_context_no_args(self):
    """A never-patched websocket context cannot be obtained without client arguments."""
    with self.assertRaises(Exception) as cm:
        get_api(ApiContext.websocket)
    self.assertEqual(
        "ApiPatcher: api_client cannot be None when context 'ApiContext.websocket' has never been patched and loaded.",
        str(cm.exception),
    )
def start(self, host="localhost", port=8080):
    """Create the websocket api, expose it as ``msa_api`` and run ``self._start()``.

    Parameters
    ----------
    host : str
        Server host name. Defaults to ``"localhost"``, the previously hard-coded value.
    port : int
        Server port. Defaults to ``8080``, the previously hard-coded value.
    """
    self.api = get_api(
        ApiContext.websocket,
        self.config["plugin_modules"],
        loop=self.loop,
        interact=self.interact,
        propagate=self.propagate,
        host=host,
        port=port,
    )
    # Make the api reachable under the name ``msa_api`` in self.globals.
    self.globals["msa_api"] = self.api
    # Blocks until the self._start() coroutine completes.
    self.loop.run_until_complete(self._start())
def load(self):
    """Return the REST api, creating and connecting it on first use.

    The api is cached on ``self.api`` only after ``client.connect()`` has
    succeeded. If connecting raises, ``self.api`` stays unset, so the next call
    retries instead of returning an api whose client never connected.

    Returns
    -------
    The api object returned by ``get_api(ApiContext.rest, ...)``.
    """
    if self.api is not None:
        return self.api
    api = get_api(
        ApiContext.rest,
        self.config["plugin_modules"],
        host=self.host,
        port=self.port,
    )
    run_async(api.client.connect())
    self.api = api
    return self.api
def test_get_api_local(self, import_module_mock):
    """get_api(ApiContext.local, ...) builds an MsaApi backed by a local client."""
    fake_loop = MagicMock()
    white_list = ["a", "b", "c"]
    # One fake module per builtin module plus one per whitelisted plugin name,
    # handed out in order by the patched import function
    # (presumably builtins first, then plugins — TODO confirm against the loader).
    import_module_mock.side_effect = [
        FakeModule() for _ in range(len(builtin_module_names) + len(white_list))
    ]
    api_instance = get_api(ApiContext.local, plugin_whitelist=white_list, loop=fake_loop)
    self.assertIsInstance(api_instance, patchable_api.MsaApi)
    self.assertEqual(api_instance.context, ApiContext.local)
    self.assertIsInstance(api_instance.client, api_clients.ApiLocalClient)
    # assertIn reports both operands on failure, unlike assertTrue(x in y).
    self.assertIn("ping", api_instance)
def __init__(self, loop=None):
    """Attach this instance to the state shared by every ScriptExecutionManager.

    All instances share one ``__dict__`` held in
    ``ScriptExecutionManager.shared_state``. The first construction creates and
    fills it. Later constructions only re-attach to it, so their ``loop``
    argument is ignored.

    If filling the shared state raises (e.g. in ``get_api``), the shared state
    is reset to ``None`` so the next construction retries. Without the reset, a
    half-initialized state would be kept forever.

    Parameters
    ----------
    loop : optional
        Stored as ``self.loop``; only used by the first construction.
    """
    if ScriptExecutionManager.shared_state is not None:
        self.__dict__ = ScriptExecutionManager.shared_state
        return

    # Publish the shared dict before filling it (same ordering as before), so any
    # instance created during initialization sees the same state object.
    ScriptExecutionManager.shared_state = {}
    self.__dict__ = ScriptExecutionManager.shared_state
    try:
        root_logger = logging.getLogger("msa")
        # getChild() appends to the parent's name, so the suffix must not repeat
        # "msa." (that produced "msa.msa.builtins...").
        self.logger = root_logger.getChild(
            "builtins.scripting.script_execution_manager.ScriptManager"
        )
        self.loop = loop
        self.running_scripts = set()
        self.scheduled_scripts = {}
        self.local_api = get_api(ApiContext.local)
        # Globals exposed to executed scripts; the local api is available as msa_api.
        self.globals = {"msa_api": self.local_api}
        self.func_locals = {}
        self.locals = {}
    except BaseException:
        # Roll back so a failed first construction doesn't poison every later one.
        ScriptExecutionManager.shared_state = None
        raise
def test_get_api_websockets_no_plugin_whitelist(self):
    """get_api refuses websocket keyword arguments when no plugin whitelist is given."""
    fake_loop = MagicMock()

    async def propagate(self, event_queue):
        # Consume one event; the value itself is not needed by this test.
        await event_queue.get()

    async def interact(self):
        pass

    with self.assertRaises(Exception) as cm:
        get_api(
            ApiContext.websocket,
            loop=fake_loop,
            interact=interact,
            propagate=propagate,
            host="localhost",
            port="8080",
        )
    self.assertEqual(
        "get_api: plugin_whitelist cannot be none when **kwargs are provided to get_api.",
        str(cm.exception),
    )
def init(self, loop, cli_config, route_adapter):
    """Initializes the supervisor.

    Loads the configuration and the builtin and plugin modules. For each
    module, registers its server api, collects its entities, validates its
    config against its ``config_schema`` and instantiates its handler
    factories.

    Parameters
    ----------
    loop : asyncio event loop
        The event loop the supervisor should use. Only stored on ``self.loop``
        (together with the event bus and queue) when the ``TEST`` environment
        variable is unset.
    cli_config : dict
        Configuration options derived from the command line interface; its
        ``"cli_overrides"`` entry is passed to ``ConfigManager``.
    route_adapter :
        Passed to each module's ``register_server_api`` callable.
        NOTE(review): its concrete type is not visible here — document once confirmed.

    Raises
    ------
    Exception
        If a loaded module does not define a ``config_schema`` that is a
        ``schema.Schema`` instance. Errors raised by
        ``config_schema.validate`` propagate unchanged.
    """
    import sys

    # Skip creating loop-bound objects when running unit tests; this helps
    # suppress a warning.
    if not os.environ.get("TEST"):
        self.loop = loop
        self.event_bus = EventBus(self.loop)
        # asyncio.Queue's first positional parameter is ``maxsize``: passing the
        # loop positionally made full()/put() raise TypeError. The ``loop``
        # keyword was removed in Python 3.10, where the queue binds lazily.
        if sys.version_info >= (3, 10):
            self.event_queue = asyncio.Queue()
        else:
            self.event_queue = asyncio.Queue(loop=self.loop)

    # ### PLACEHOLDER - Load Configuration file here --
    self.config_manager = ConfigManager(cli_config["cli_overrides"])
    config = self.config_manager.get_config()

    # The returned local api was never used here; the call is kept for its side
    # effects (presumably patching the client api — see the note in the loop below).
    get_api(ApiContext.local, config["plugin_modules"], loop=loop)
    server_api_binder = route_adapter

    # Initialize logging
    self.init_logging(config["logging"])

    plugin_names = config["plugin_modules"]

    # ### Loading Modules
    self.logger.info("Loading modules.")

    self.logger.debug("Loading builtin modules.")
    builtin_modules = load_builtin_modules()
    self.logger.debug("Finished loading builtin modules.")

    self.logger.debug("Loading plugin modules.")
    plugin_modules = load_plugin_modules(plugin_names)
    self.logger.debug("Finished loading plugin modules.")

    self.logger.info("Finished loading modules.")
    self.loaded_modules = builtin_modules + plugin_modules

    # ### Registering Handlers
    self.logger.info("Registering handlers.")
    for module in self.loaded_modules:
        # Note client api registration is handled by the patcher.
        # module.__name__ appears to already start with "msa." (it is sliced
        # with [4:] below), so it is logged as-is rather than as "msa.msa...".
        self.logger.debug(
            "Registering server api endpoints for module {}".format(module.__name__))
        if hasattr(module, "register_server_api") and callable(module.register_server_api):
            module.register_server_api(server_api_binder)

        if hasattr(module, "entities_list") and isinstance(module.entities_list, list):
            __models__.extend(module.entities_list)

        # A module's config section is keyed by the last component of its name.
        module_name_tail = module.__name__.split(".")[-1]
        module_config = config["module_config"].get(module_name_tail, None)

        if not (hasattr(module, "config_schema") and isinstance(module.config_schema, Schema)):
            raise Exception(
                "All modules must define a `config_schema` property that is an instance of "
                "schema.Schema")

        self.logger.debug(f"Validating module {module.__name__} config schema.")
        validated_config = module.config_schema.validate(module_config)

        self.logger.debug("Registering handlers for module {}".format(module.__name__))
        for handler in module.handler_factories:
            # Drop the leading "msa." (assumed prefix) so the handler logger is
            # created as a child of self.root_logger.
            namespace = "{}.{}".format(module.__name__[4:], handler.__name__)
            full_namespace = "msa.{}".format(namespace)
            self.logger.debug("Registering handler: {}".format(full_namespace))
            handler_logger = self.root_logger.getChild(namespace)
            self.loggers[full_namespace] = handler_logger
            inited_handler = handler(self.loop, self.event_bus, handler_logger, validated_config)
            self.initialized_event_handlers.append(inited_handler)
            self.handler_lookup[handler] = inited_handler
            self.logger.debug("Finished registering handler: {}".format(full_namespace))
        self.logger.debug("Finished registering handlers for module {}".format(module.__name__))
    self.logger.info("Finished registering handlers.")

    self.apply_granular_log_levels(config["logging"]["granular_log_levels"])
from prompt_toolkit.shortcuts.prompt import PromptSession
from prompt_toolkit.eventloop.defaults import use_asyncio_event_loop
import msa
from msa.api import get_api, run_async
from msa.api.context import ApiContext

# NOTE(review): ``asyncio`` is used below but no ``import asyncio`` is visible in
# this snippet — confirm it is imported elsewhere in the file.

# Make prompt_toolkit use the asyncio event loop, then grab that loop.
use_asyncio_event_loop()
loop = asyncio.get_event_loop()

# Address of the MSA server this client talks to.
host = "localhost"
port = 8080
# No plugin modules are requested for this client.
plugins = []

# a very simple conversational client
api = get_api(ApiContext.rest, plugins, host=host, port=port)


async def startup_check():
    """Connect to the server and print a warning if its version is not the expected one."""
    await api.client.connect()
    # NOTE(review): the handler(s) for this ``try`` are not visible in this snippet.
    try:
        await api.check_connection()
        # Server version this client was written against.
        expected_version = "0.1.0"
        server_version = await api.get_version()
        if expected_version != server_version:
            print("Server version does not match required version!")
            print("Expected version:", expected_version)
            print("Server version: ", server_version)
def test_get_api_no_context(self):
    """Passing ``None`` as the context is rejected with a descriptive error."""
    with self.assertRaises(Exception) as raised:
        get_api(None)
    message = str(raised.exception)
    self.assertEqual("get_api: context cannot be None.", message)