async def test_publish(self):
    """
    Publishing a TEST_MESSAGE on a channel should report success.
    """
    handler = NatsHandler("test",
                          "0.0.0.0",
                          "4222",
                          loop=self._asyncioTestLoop,
                          user="******",
                          password="******")
    await handler.connect()

    payload = {
        "sender_ID": "User",
        "time_sent": "2020-07-06",
        "data": {
            "testData": "This is a test"
        }
    }
    outgoing = Message.decode_json(payload, MessageSchemas.TEST_MESSAGE)

    published = await handler.send_message("subscribe-test", outgoing)
    self.assertTrue(published)

    await handler.disconnect()
async def test_send_data(self):
    """
    Sending data on a channel should succeed and leave exactly one entry
    in the handler's data table.
    """
    handler = NatsHandler("test",
                          "0.0.0.0",
                          "4222",
                          loop=self._asyncioTestLoop,
                          user="******",
                          password="******")
    await handler.connect()

    payload = {
        "sender_ID": "User",
        "time_sent": "2020-07-06",
        "data": {
            "testData": "This is a test"
        }
    }
    outgoing = Message.decode_json(payload, MessageSchemas.TEST_MESSAGE)

    sent_ok = await handler.send_data("subscribe-test", outgoing)
    self.assertTrue(sent_ok)

    await handler.disconnect()
    self.assertEqual(len(handler.data_table), 1)
async def run():
    """
    Startup routine.

    Steps, in order:
    1. Connect to NATS.
    2. Create the logger.
    3. Load the initial config.
    4. Register all NATS callbacks.
    5. Invoke the startup callback, if one was registered.

    A startup callback that declares four parameters additionally receives
    self.kubernetes_client.
    """
    # Reset the loop's exception handler to the default so exceptions actually
    # stop code execution instead of going unnoticed.
    event_loop = asyncio.get_running_loop()
    event_loop.set_exception_handler(None)

    # connect to the NATS server
    self.nats_client = NatsHandler("default",
                                   host=self.nats_host,
                                   port=self.nats_port,
                                   user=self.nats_user,
                                   password=self.nats_password,
                                   api_host=self.api_host,
                                   api_port=self.api_port,
                                   loop=event_loop)
    await self.nats_client.connect()

    # logger that reports through the NATS client
    self._logger = NatsLoggerFactory.get_logger(self.nats_client,
                                                self.service_type)

    # retrieve sender_id and initial shared_storage
    await self._load_config()
    self.nats_client.sender_id = self.sender_id

    # activate all NATS subscriptions
    await self._register_callbacks()

    startup = self._startup_callback
    if not startup:
        return

    startup_args = [self.nats_client, self.shared_storage, self._logger]
    # a four-parameter callback also wants the kubernetes client
    if len(signature(startup).parameters) == 4:
        startup_args.append(self.kubernetes_client)
    await startup(*startup_args)
async def test_retrieve_data_message(self):
    """
    Retrieving a buffered data message should return a message with the same
    data and remove it from the data table. Retrieving the same ID a second
    time should raise KeyError.
    """
    handler = NatsHandler("test",
                          "0.0.0.0",
                          "4222",
                          loop=self._asyncioTestLoop,
                          user="******",
                          password="******")

    stored = handler.create_message({"testData": "This is a test"},
                                    MessageSchemas.TEST_MESSAGE)
    handler.data_table["someID"] = stored
    self.assertEqual(len(handler.data_table), 1)

    retrieved = await handler.retrieve_data_message("someID")
    self.assertEqual(stored.data, retrieved.data)
    self.assertEqual(len(handler.data_table), 0)

    # the entry was consumed above, so a second lookup must fail
    with self.assertRaises(KeyError):
        await handler.retrieve_data_message("someID")
async def test_connect(self):
    """
    connect() against the local NATS server should report success.
    """
    handler = NatsHandler("test",
                          "0.0.0.0",
                          "4222",
                          loop=self._asyncioTestLoop,
                          user="******",
                          password="******")
    connected = await handler.connect()
    self.assertTrue(connected)
    await handler.disconnect()
async def test_request_response(self):
    """
    Testing whether request response works and whether the callback is called.

    A callback subscribed on "response-test" echoes each decoded request back
    on the request's reply subject. The test checks the response it gets back:
    sender_ID is expected to be "test" (this handler's name), while origin_ID
    remains "User".
    The connection is always closed, even when an assertion fails.
    """
    loop = self._asyncioTestLoop
    # reset the loop's exception handler to the default one
    loop.set_exception_handler(None)
    nats = NatsHandler("test",
                       "0.0.0.0",
                       "4222",
                       loop=loop,
                       user="******",
                       password="******")
    await nats.connect()
    try:

        async def callback(msg):
            print("Got message")
            # keep the raw message: its .reply subject is needed to answer
            raw_message = msg
            msg = Message.decode_raw(msg.data, MessageSchemas.TEST_MESSAGE)
            print(msg)
            # NOTE(review): an assertion failure raised inside the callback may
            # not fail the test by itself; the response check below guards it.
            self.assertEqual(
                msg.encode_json(), {
                    "sender_ID": "User",
                    "origin_ID": "User",
                    "message_type": "test_message",
                    "time_sent": "2020-07-06",
                    "data": {
                        "testData": "This is a test"
                    }
                })
            await nats.send_message(raw_message.reply, msg)

        await nats.subscribe_callback("response-test", callback)

        message = Message.decode_json(
            {
                "sender_ID": "User",
                "time_sent": "2020-07-06",
                "data": {
                    "testData": "This is a test"
                }
            }, MessageSchemas.TEST_MESSAGE)
        response = await nats.request_message("response-test",
                                              message,
                                              MessageSchemas.TEST_MESSAGE,
                                              timeout=1)
        self.assertEqual(
            response.encode_json(), {
                "sender_ID": "test",
                "origin_ID": "User",
                "message_type": "test_message",
                "time_sent": "2020-07-06",
                "data": {
                    "testData": "This is a test"
                }
            })

        result = await nats.unsubscribe_callback("response-test", callback)
        self.assertTrue(result)
    finally:
        # always release the connection so a failing assertion does not leak it
        await nats.disconnect()
async def test_subscribe(self):
    """
    Subscribing a callback to the "subscribe-test" channel should succeed.
    """
    handler = NatsHandler("test",
                          "0.0.0.0",
                          "4222",
                          loop=self._asyncioTestLoop,
                          user="******",
                          password="******")
    await handler.connect()

    async def callback(msg):
        print("Got message")

    subscribed = await handler.subscribe_callback("subscribe-test", callback)
    self.assertTrue(subscribed)

    await handler.disconnect()
async def test_receive(self):
    """
    TODO: LOOK at first comment inside test--see if resolved yet

    Testing whether receiving actually works and whether the callback is called.
    Could not figure out yet how to assert within the callback.
    USE WITH CAUTION!! Must check the print output to see if the messages were
    actually received.
    """
    # TODO: Figure out how to assert within the callback and check that it is called in the first place
    print("Starting")
    loop = self._asyncioTestLoop
    nats = NatsHandler("test",
                       "0.0.0.0",
                       "4222",
                       loop=loop,
                       user="******",
                       password="******")
    await nats.connect()

    async def callback(msg):
        print("Got message")
        print(msg)
        # assertEqual, not the deprecated assertEquals alias that Python 3.12 removed
        self.assertEqual(
            Message.decode_raw(msg.data,
                               MessageSchemas.TEST_MESSAGE).encode_json(), {
                                   "sender_ID": "User",
                                   "time_sent": "2020-07-06",
                                   "data": {
                                       "testData": "This is a test"
                                   }
                               })
        print("Is equal")
        # NOTE(review): presumably raised on purpose to see whether callback
        # exceptions surface at all — confirm before removing.
        raise ValueError("TEST")

    await nats.subscribe_callback("subscribe-test", callback)
    message = Message.decode_json(
        {
            "sender_ID": "User",
            "time_sent": "2020-07-06",
            "data": {
                "testData": "This is a test"
            }
        }, MessageSchemas.TEST_MESSAGE)
    await nats.send_message("subscribe-test", message)
    await nats.disconnect()
async def test_unsubscribe(self):
    """
    Unsubscribing from a channel the callback was never subscribed to ("foo")
    should return False. Unsubscribing from the channel it was subscribed to
    should return True.
    """
    handler = NatsHandler("test",
                          "0.0.0.0",
                          "4222",
                          loop=self._asyncioTestLoop,
                          user="******",
                          password="******")
    await handler.connect()

    async def callback(msg):
        print("Got message")

    await handler.subscribe_callback("subscribe-test", callback)

    unknown_removed = await handler.unsubscribe_callback("foo", callback)
    self.assertFalse(unknown_removed)

    known_removed = await handler.unsubscribe_callback("subscribe-test", callback)
    self.assertTrue(known_removed)

    await handler.disconnect()
async def test_check_status(self):
    """
    Testing check_status against a shared storage whose simulation clock
    flag starts out True.

    After check_status runs, the clock flag is expected to be False.
    Presumably this is because no clock service answers in the test
    environment — confirm against check_status.
    The NATS connection is always closed afterwards.
    """
    loop = asyncio.get_running_loop()
    shared_storage = {
        "simulation": {
            "clock": True,
            "logging": False,
            "czml": False,
            "config": False
        },
        "cubesats": {
            "cubesat_1": {
                "orbits": False,
                "rl": False,
                "rl_training": False,
                "data": False,
                "agriculture": False
            }
        },
        "groundstations": {
            "groundstation_1": {
                "groundstation": False
            }
        },
        "iots": {
            "iot_1": {
                "iot": False
            }
        },
        "config_path": "./simulation_config"
    }
    nats = NatsHandler("data1",
                       "0.0.0.0",
                       "4222",
                       loop=loop,
                       user="******",
                       password="******")
    await nats.connect()
    try:
        logger = FakeLogger()
        await check_status(nats, shared_storage, logger)
        # assertEqual keeps the `== False` semantics but reports the actual value
        self.assertEqual(shared_storage["simulation"]["clock"], False)
    finally:
        # the original test never closed its connection
        await nats.disconnect()
async def test_create_message(self):
    """
    create_message should copy the handler's sender_id and time_sent onto the
    new message and keep the given data. Building the same data against
    ORBIT_MESSAGE should raise ValidationError.
    """
    handler = NatsHandler("test",
                          "0.0.0.0",
                          "4222",
                          loop=self._asyncioTestLoop,
                          user="******",
                          password="******")
    handler.time_sent = "2020-07-06"
    handler.sender_id = "1"

    created = handler.create_message({"testData": "This is a test"},
                                     MessageSchemas.TEST_MESSAGE)
    self.assertEqual(created.sender_id, "1")
    self.assertEqual(created.time_sent, "2020-07-06")
    self.assertEqual(created.data, {"testData": "This is a test"})

    # the same payload does not satisfy the orbit schema
    with self.assertRaises(ValidationError):
        handler.create_message({"testData": "This is a test"},
                               MessageSchemas.ORBIT_MESSAGE)
async def simulation_timepulse(message: Message, nats_handler: NatsHandler,
                               shared_storage: dict, logger: JsonLogger):
    """
    Store the time carried in a timestep message on the NATS handler.

    Args:
        message: Incoming message; its data must contain a "time" entry.
        nats_handler: Handler whose time_sent attribute is updated.
        shared_storage: Unused.
        logger: Unused.
    """
    current_time = message.data["time"]
    nats_handler.time_sent = current_time
async def heartbeat(message: Message, nats_handler: NatsHandler,
                    shared_storage: dict, logger: JsonLogger) -> Message:
    """
    Answer a status ping.

    Returns:
        Message: A STATUS_MESSAGE whose data is "ALIVE".
    """
    status_reply = nats_handler.create_message("ALIVE",
                                               MessageSchemas.STATUS_MESSAGE)
    return status_reply
async def run():
    """
    Startup routine: connect to NATS, create the logger, obtain sender_id and
    the initial shared_storage, register callbacks, and invoke the startup
    callback.

    sender_id and shared_storage are taken from the first source that works:
    1. The config file at self.config_path, if one was given.
    2. Otherwise, the config service on NATS channel "initialize.service".
    3. If the config service request fails, the values buffered in Redis.

    Raises:
        ValueError: If both the config service and Redis fail to provide a
            configuration.
    """
    # Reset the loop's exception handler to the default so exceptions actually
    # stop code execution instead of going unnoticed.
    loop = asyncio.get_running_loop()
    loop.set_exception_handler(None)

    # connect to the NATS server
    self.nats_client = NatsHandler("default",
                                   host=self.nats_host,
                                   port=self.nats_port,
                                   user=self.nats_user,
                                   password=self.nats_password,
                                   api_host=self.api_host,
                                   api_port=self.api_port,
                                   loop=loop)
    await self.nats_client.connect()

    # creating logger
    self._logger = NatsLoggerFactory.get_logger(self.nats_client,
                                                self.service_type)

    # retrieving initial shared_storage
    if self.config_path is not None:
        # if a path to a config file is given, initializes from there
        with open(self.config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
        # get own sender_id from config
        self.sender_id = config["sender_id"]
        # validate the shared_storage section of the config
        validate_json(config["shared_storage"], self._schema)
        self.shared_storage = config["shared_storage"]
        # write the shared storage and sender ID to Redis
        self.redis_client.set_shared_storage(self.shared_storage)
        self.redis_client.set_sender_id(self.sender_id)
        print(
            f"Successfully initialized {self.sender_id} {self.service_type} from file"
        )
    else:
        try:
            # requesting a config from the config service
            message = self.nats_client.create_message(
                self.service_type, MessageSchemas.SERVICE_TYPE_MESSAGE)
            print(
                f"Requesting config from config service for node {self.service_type}"
            )
            config_response = await self.nats_client.request_message(
                "initialize.service",
                message,
                MessageSchemas.CONFIG_MESSAGE,
                timeout=3)
            print(f"Got config from config service: {config_response}")
            print("Validating ...")
            # validate the shared storage section of the config
            validate_json(config_response.data["shared_storage"],
                          self._schema)
            self.sender_id = config_response.data["sender_id"]
            self.shared_storage = config_response.data["shared_storage"]
            # write the shared storage and sender ID to Redis
            self.redis_client.set_sender_id(self.sender_id)
            self.redis_client.set_shared_storage(self.shared_storage)
            print(
                f"Successfully initialized {self.sender_id} {self.service_type} from config service"
            )
        except Exception as config_error:
            # `except Exception` (not a bare except) so task cancellation and
            # KeyboardInterrupt propagate instead of triggering the fallback
            print(
                f"Config service initialization failed ({config_error!r}), trying redis"
            )
            try:
                # try initializing from redis
                self.sender_id = self.redis_client.get_sender_id()
                if not self.sender_id:
                    raise ValueError("Could not get sender id from redis")
                self.shared_storage = self.redis_client.get_shared_storage()
                print(
                    f"Successfully initialized {self.sender_id} {self.service_type} from redis"
                )
            except Exception as e:
                raise ValueError(
                    f"Failed to initialize from redis. Aborting. Error: {e}"
                ) from e

    # setting nats sender id
    self.nats_client.sender_id = self.sender_id

    # registering callbacks
    await self._register_callbacks()

    # execute startup callback
    if self._startup_callback:
        await self._startup_callback(self.nats_client, self.shared_storage,
                                     self._logger)
class BaseSimulation():
    """
    Class that provides functionality for registering callbacks on different NATS channels.

    This is the backbone of the simulation framework. It is responsible for connecting
    with NATS and Redis, loading configs, and running the python asyncio event loop.
    It combines NATS and a FastAPI application and runs them in the same event loop,
    so that extending with more REST endpoints is possible in the future. Currently
    the REST API is only used for special callbacks that are meant to send data too
    large to be transmitted over NATS.
    """

    def __init__(self, service_type: str, schema: dict, config_path: str = None):
        """
        Initializes the base simulation.

        Registers two default NATS callbacks:
        - One on channel "simulation.timestep". It just updates the current time
          internal to the simulation's NATS handler object.
        - One on "node.status.{SERVICE_TYPE}.{SENDER_ID}". It provides a
          request/response endpoint that can be pinged to see if the service is
          alive.

        Args:
            service_type (string): Name of the service type for this simulation instance
            schema (dict): Schema used to validate the shared_storage internal to this
                simulation instance.
            config_path (string, optional): Path to a config file. If not None, will be
                used to get the sender_id and initial shared storage of this simulation
                instance once the BaseSimulation.run method is called
        """
        super().__init__()
        # initializing nats
        self.service_type = service_type
        self._schema = schema
        self.config_path = config_path
        self.sender_id = None
        self.nats_client = None
        self.redis_client = None
        self._api = None
        self._logger = None
        self.shared_storage = None
        self._startup_callback = None
        self._registered_callbacks = []
        self._unsubscribe_nats_routes = []
        # Strong references to tasks started by schedule_callback. The event loop
        # only keeps weak references to tasks, so unreferenced tasks may be
        # garbage collected mid-execution.
        self._scheduled_tasks = []

        # subscribing to timestep by default to update time in nats_handler
        @self.subscribe_nats_callback("simulation.timestep",
                                      MessageSchemas.TIMESTEP_MESSAGE)
        async def simulation_timepulse(message: Message, nats_handler: NatsHandler,
                                       shared_storage: dict, logger: JsonLogger):
            nats_handler.time_sent = message.data["time"]

        # subscribing to node status by default to provide channel to ping and see whether service is alive
        @self.request_nats_callback(f"node.status.{self.service_type}.",
                                    MessageSchemas.STATUS_MESSAGE,
                                    append_sender_id=True)
        async def heartbeat(message: Message, nats_handler: NatsHandler,
                            shared_storage: dict, logger: JsonLogger) -> Message:
            return nats_handler.create_message("ALIVE",
                                               MessageSchemas.STATUS_MESSAGE)

    def _snapshot_shared_storage(self) -> dict:
        """
        Return a deep copy of the shared storage for a callback to work on.

        The shared storage can be nested. With a shallow copy, a callback that
        mutates a nested value would change self.shared_storage directly and
        bypass schema validation.
        """
        import copy
        return copy.deepcopy(self.shared_storage)

    def _commit_shared_storage(self, shared_storage: dict):
        """
        Validate a callback's working copy of the shared storage. If it is valid,
        make it the current shared storage and buffer it in Redis.

        Raises:
            ValueError: If the working copy no longer matches the schema.
        """
        if not validate_json(shared_storage, self._schema):
            raise ValueError("Invalid change in shared storage")
        self.shared_storage = shared_storage
        # buffer the current shared storage in redis
        self.redis_client.set_shared_storage(self.shared_storage)

    def _request_channel(self, channel: str, append_sender_id: bool) -> str:
        """
        Return the channel name a request callback is actually subscribed on.

        This is the channel with the runtime sender_id appended when
        append_sender_id is True.
        """
        return channel + self.sender_id if append_sender_id else channel

    async def _register_callbacks(self):
        """
        Private method that activates all the NATS channel subscriptions.
        """
        for route in self._registered_callbacks:
            await route()

    async def _stop(self):
        """
        Stops the simulation.

        Cancels scheduled callbacks, unsubscribes the NATS callbacks, and
        disconnects from the NATS server.
        """
        # stop scheduled callbacks so they don't keep running against a closed connection
        for task in self._scheduled_tasks:
            task.cancel()
        self._scheduled_tasks.clear()
        if self.nats_client is None:
            # startup never connected, so there is nothing to unsubscribe or close
            return
        for unsubscribe_route in self._unsubscribe_nats_routes:
            await unsubscribe_route()
        await self.nats_client.disconnect()

    def startup_callback(self, callback_function: Callable) -> Callable:
        """
        Decorator used to register a callback that will be called at simulation
        startup in the BaseSimulation.run() method.

        The callback will be called with arguments nats_handler, shared_storage,
        logger (in that order). Only one startup callback can be registered;
        subsequent calls overwrite the previously set callback.

        Usage example:
            @base_simulation_instance.startup_callback
            async def sample_callback(nats_handler, shared_storage, logger):
                print("I am the startup callback")

        Args:
            callback_function (function): Async callback function that should be executed.

        Returns:
            function: The original function that was passed as argument
        """
        self._startup_callback = callback_function
        return callback_function

    def subscribe_nats_callback(self, channel: str, message_schema: dict) -> Callable:
        """
        Decorator used to register a callback for a specific NATS channel.

        The actual registration of the callback with the NATS server happens when
        BaseSimulation.run() is called. Will call the callback with arguments
        message, nats_handler, shared_storage, logger (in that order). The
        callback receives a copy of the shared storage. The copy only becomes the
        current shared storage if it still validates against the schema.
        Exceptions are logged and do not stop further messages from being handled.

        Usage example:
            @base_simulation_instance.subscribe_nats_callback("sample.route", MessageSchema)
            async def sample_callback(msg, nats, shared_storage, logger):
                print(msg.data)

        Args:
            channel (string): Name of the channel that the callback should be registered with.
            message_schema (dict): Schema to validate incoming messages against

        Returns:
            function: Returns decorator function that takes in the actual callback.
        """

        def decorator(callback_function: Callable) -> Callable:
            # wrap the callback so we can actually subscribe once the simulation runs
            async def subscription_wrapper():

                async def callback_wrapper(msg):
                    # try executing the callback and log if exception occurs
                    try:
                        msg = Message.decode_raw(msg.data, message_schema)
                        # work on a copy, so callback cannot perform invalid changes
                        shared_storage = self._snapshot_shared_storage()
                        await callback_function(msg, self.nats_client,
                                                shared_storage, self._logger)
                        self._commit_shared_storage(shared_storage)
                    except Exception:
                        await self._logger.error(traceback.format_exc())

                # subscribe to the NATS channel
                await self.nats_client.subscribe_callback(
                    channel, callback_wrapper, orig_callback=callback_function)

            self._registered_callbacks.append(subscription_wrapper)

            # create a wrapper so we can unsubscribe at a later time
            async def unsubscription_wrapper():
                return await self.nats_client.unsubscribe_callback(
                    channel, callback_function)

            self._unsubscribe_nats_routes.append(unsubscription_wrapper)
            return callback_function

        return decorator

    def request_nats_callback(self,
                              channel: str,
                              message_schema: dict,
                              append_sender_id: bool = True) -> Callable:
        """
        Decorator used to register a request callback for a specific NATS channel.

        Any callback registered using this decorator is expected to return an
        object of type Message. That Message is sent back via NATS to the
        sender. Messages sent to the registered channel must therefore be sent
        using NatsHandler.request_message. The actual registration of the
        callback with the NATS server happens when BaseSimulation.run() is
        called. If append_sender_id is True, the sender_id of the simulation
        object will be appended to the channel name. Will call the callback
        with arguments message, nats_handler, shared_storage, logger (in that
        order).

        Usage example:
            @base_simulation_instance.request_nats_callback("sample-route", MessageSchema, append_sender_id=True)
            async def sample_callback(msg, nats, shared_storage, logger):
                print(msg.data)

        Args:
            channel (string): Name of the channel that the callback should be registered with.
            message_schema (dict): Schema to validate incoming messages against
            append_sender_id (bool): Indicates whether the sender_id should be appended
                to the channel name at simulation runtime.

        Returns:
            function: Returns decorator function that takes in the actual callback.
        """

        def decorator(callback_function: Callable) -> Callable:
            # wrap the callback so we can actually subscribe once the simulation runs
            async def request_wrapper():

                async def callback_wrapper(msg):
                    # try executing the callback and log if exception occurs
                    try:
                        # keep the raw message to preserve the response channel name
                        raw_message = msg
                        msg = Message.decode_raw(msg.data, message_schema)
                        # work on a copy, so callback cannot perform invalid changes
                        shared_storage = self._snapshot_shared_storage()
                        response = await callback_function(
                            msg, self.nats_client, shared_storage, self._logger)
                        self._commit_shared_storage(shared_storage)
                        # send the response via NATS
                        await self.nats_client.send_message(
                            raw_message.reply, response)
                    except Exception:
                        await self._logger.error(traceback.format_exc())

                # subscribe to the NATS channel (sender_id resolved at runtime)
                await self.nats_client.subscribe_callback(
                    self._request_channel(channel, append_sender_id),
                    callback_wrapper,
                    orig_callback=callback_function)

            self._registered_callbacks.append(request_wrapper)

            # Unsubscribe from the same channel that was subscribed to. Using the
            # bare `channel` here would miss subscriptions that have the sender_id
            # appended.
            async def unsubscription_wrapper():
                return await self.nats_client.unsubscribe_callback(
                    self._request_channel(channel, append_sender_id),
                    callback_function)

            self._unsubscribe_nats_routes.append(unsubscription_wrapper)
            return callback_function

        return decorator

    def schedule_callback(self, timeout: float) -> Callable:
        """
        Decorator used to register a callback to be executed in a regular time interval.

        The actual registration of the callback happens when BaseSimulation.run()
        is called, so the callback will not be active before then. Will call the
        callback with arguments nats_handler, shared_storage, logger (in that
        order). Each run works on a copy of the shared storage. The copy is only
        kept if it still validates against the schema. An exception in one run is
        logged and does not stop later runs.

        Usage example:
            @base_simulation_instance.schedule_callback(1)
            async def sample_callback(nats, shared_storage, logger):
                print("hi!")

        Args:
            timeout (float): Timeout to wait until running the callback again.

        Returns:
            Returns decorator function that takes in the actual callback.
        """

        def decorator(callback_function: Callable) -> Callable:
            # wrap the callback so we can actually start it once the simulation runs
            async def subscription_wrapper():

                async def callback_wrapper():
                    while True:
                        # Per-iteration try, so one failing run does not end the schedule.
                        # The old code validated the stale self.shared_storage and, on
                        # failure, overwrote the schema with it.
                        try:
                            shared_storage = self._snapshot_shared_storage()
                            await callback_function(self.nats_client,
                                                    shared_storage, self._logger)
                            self._commit_shared_storage(shared_storage)
                        except Exception:
                            await self._logger.error(traceback.format_exc())
                        # timeout until next loop execution
                        await asyncio.sleep(timeout)

                # start the task and keep a reference so it is not garbage collected
                loop = asyncio.get_running_loop()
                self._scheduled_tasks.append(loop.create_task(callback_wrapper()))

            self._registered_callbacks.append(subscription_wrapper)
            return callback_function

        return decorator

    def subscribe_data_callback(self, channel, message_schema, validator=None):
        """
        Decorator used to register a callback for a NATS channel that is used to
        send data via the REST API.

        Any broadcasting on channels attached to this callback should be done
        with NatsHandler.send_data(). Internally this callback expects messages
        of schema API_MESSAGE on the registered channel. It then:
        1. Extracts the host, port, and route of the GET endpoint for the data.
        2. Makes an HTTP request to that endpoint.
        3. Parses the response into a Message object of schema message_schema.

        The actual registration of the callback with the NATS server happens
        when BaseSimulation.run() is called. Will call the callback with
        arguments message, nats_handler, shared_storage, logger (in that order).
        Changes the callback makes to shared_storage are validated and committed
        like for subscribe_nats_callback.

        Optionally, you can provide a validator as a keyword argument. It should
        have the same signature as a callback function and return True or False,
        depending on whether the message coming from the REST API should be
        processed.

        Usage example:
            @base_simulation_instance.subscribe_data_callback("sample.route", MessageSchema, validator=some_validator_function)
            async def sample_callback(msg, nats, shared_storage, logger):
                print(msg.data)

        Args:
            channel (string): Name of the channel that the callback should be registered with.
            message_schema (dict): Schema to validate the incoming API messages against.
            validator (function, optional): function to check whether a certain NATS API
                message should be processed. Must have args message, nats_handler,
                shared_storage, logger (in that order) and return True or False.
        """

        def decorator(callback_function):
            # subscribe to the given NATS channel but listen for messages of schema API_MESSAGE
            @self.subscribe_nats_callback(channel, MessageSchemas.API_MESSAGE)
            async def handle_api_message(message, nats, shared_storage, logger):
                # if a validator function was given, it decides whether to process the message
                if validator and not validator(message, nats, shared_storage, logger):
                    return
                async with aiohttp.ClientSession() as session:
                    # construct the URL to access the data using the info from the API message
                    url = f"http://{message.data['host']}:{message.data['port']}{message.data['route']}/{message.data['data_id']}"
                    async with session.get(url) as response:
                        # check whether GET was successful
                        if response.status == 200:
                            msg = Message.decode_json(await response.json(),
                                                      message_schema)
                            # Pass the wrapper's working copy: changes made to
                            # self.shared_storage would be overwritten when the
                            # wrapper commits its copy afterwards.
                            await callback_function(msg, nats, shared_storage, logger)
                            return
                        await self._logger.error(json.dumps(await response.json()))

            return callback_function

        return decorator

    def run(self,
            nats_host="nats",
            nats_port="4222",
            nats_user=None,
            nats_password=None,
            api_host="127.0.0.1",
            api_port=8000,
            redis_host="redis",
            redis_port=6379,
            redis_password=None):
        """
        Main entrypoint to starting the simulation.

        Will register all the callbacks with NATS and REST and start the event
        loop. The sender_id and initial shared_storage come from the first
        source that works:
        1. The config file given as config_path, if any.
        2. Otherwise, a simulation with a callback on channel "initialize.service".
        3. If that request fails, the values buffered in Redis.

        Args:
            nats_host (str, optional): NATS server host. Defaults to "nats".
            nats_port (str, optional): NATS server port. Defaults to "4222".
            nats_user (str, optional): NATS user. Defaults to None.
            nats_password (str, optional): NATS password. Defaults to None.
            api_host (str, optional): Host under which the own REST API is accessible. Defaults to "127.0.0.1".
            api_port (int, optional): Port to run the REST API on. Defaults to 8000.
            redis_host (str, optional): Redis server host. Defaults to "redis".
            redis_port (int, optional): Redis server port. Defaults to 6379.
            redis_password (str, optional): Redis server password. Defaults to None.
        """
        self.nats_host = nats_host
        self.nats_port = nats_port
        self.nats_user = nats_user
        self.nats_password = nats_password
        self.api_host = api_host
        self.api_port = api_port
        self.redis_host = redis_host
        self.redis_port = redis_port
        self.redis_password = redis_password

        # creating redis client
        self.redis_client = RedisHandler(self.service_type,
                                         self._schema,
                                         host=self.redis_host,
                                         port=self.redis_port,
                                         password=self.redis_password)

        # creating api
        self._api = FastAPI()

        # registering the data REST endpoint used to query data messages
        @self._api.get("/data/{data_id}")
        async def get_data(data_id: str):
            try:
                # retrieve the data from the NATS client buffer and return it
                message = await self.nats_client.retrieve_data_message(data_id)
                message.sender_id = self.nats_client.sender_id
                return message.encode_json()
            except Exception:
                print("Error!")
                await self._logger.error(traceback.format_exc())
                raise HTTPException(status_code=500, detail="An error occured")

        # Initialization consists of async functions, so register it as a startup
        # callback that executes once the event loop starts.
        @self._api.on_event("startup")
        async def run():
            # Reset the loop's exception handler to the default so exceptions
            # actually stop code execution instead of going unnoticed.
            loop = asyncio.get_running_loop()
            loop.set_exception_handler(None)

            # connect to the NATS server
            self.nats_client = NatsHandler("default",
                                           host=self.nats_host,
                                           port=self.nats_port,
                                           user=self.nats_user,
                                           password=self.nats_password,
                                           api_host=self.api_host,
                                           api_port=self.api_port,
                                           loop=loop)
            await self.nats_client.connect()

            # creating logger
            self._logger = NatsLoggerFactory.get_logger(self.nats_client,
                                                        self.service_type)

            # retrieving initial shared_storage
            if self.config_path is not None:
                # if a path to a config file is given, initializes from there
                with open(self.config_path, "r", encoding="utf-8") as f:
                    config = json.load(f)
                # get own sender_id from config
                self.sender_id = config["sender_id"]
                # validate the shared_storage section of the config
                validate_json(config["shared_storage"], self._schema)
                self.shared_storage = config["shared_storage"]
                # write the shared storage and sender ID to Redis
                self.redis_client.set_shared_storage(self.shared_storage)
                self.redis_client.set_sender_id(self.sender_id)
                print(
                    f"Successfully initialized {self.sender_id} {self.service_type} from file"
                )
            else:
                try:
                    # requesting a config from the config service
                    message = self.nats_client.create_message(
                        self.service_type, MessageSchemas.SERVICE_TYPE_MESSAGE)
                    print(
                        f"Requesting config from config service for node {self.service_type}"
                    )
                    config_response = await self.nats_client.request_message(
                        "initialize.service",
                        message,
                        MessageSchemas.CONFIG_MESSAGE,
                        timeout=3)
                    print(f"Got config from config service: {config_response}")
                    print("Validating ...")
                    # validate the shared storage section of the config
                    validate_json(config_response.data["shared_storage"],
                                  self._schema)
                    self.sender_id = config_response.data["sender_id"]
                    self.shared_storage = config_response.data["shared_storage"]
                    # write the shared storage and sender ID to Redis
                    self.redis_client.set_sender_id(self.sender_id)
                    self.redis_client.set_shared_storage(self.shared_storage)
                    print(
                        f"Successfully initialized {self.sender_id} {self.service_type} from config service"
                    )
                except Exception as config_error:
                    # `except Exception` (not bare) so cancellation and
                    # KeyboardInterrupt are not mistaken for a config failure
                    print(
                        f"Config service initialization failed ({config_error!r}), trying redis"
                    )
                    try:
                        # try initializing from redis
                        self.sender_id = self.redis_client.get_sender_id()
                        if not self.sender_id:
                            raise ValueError("Could not get sender id from redis")
                        self.shared_storage = self.redis_client.get_shared_storage()
                        print(
                            f"Successfully initialized {self.sender_id} {self.service_type} from redis"
                        )
                    except Exception as e:
                        raise ValueError(
                            f"Failed to initialize from redis. Aborting. Error: {e}"
                        ) from e

            # setting nats sender id
            self.nats_client.sender_id = self.sender_id

            # registering callbacks
            await self._register_callbacks()

            # execute startup callback
            if self._startup_callback:
                await self._startup_callback(self.nats_client,
                                             self.shared_storage, self._logger)

        # registering the nats shutdown with the api server
        @self._api.on_event("shutdown")
        async def stop():
            await self._stop()

        # run application
        uvicorn.run(self._api,
                    host="0.0.0.0",
                    port=self.api_port,
                    debug=False,
                    log_level='error')
def run(self,
        nats_host="nats",
        nats_port="4222",
        nats_user=None,
        nats_password=None,
        api_host="127.0.0.1",
        api_port=8000,
        redis_host="redis",
        redis_port=6379,
        redis_password=None):
    """
    Train a DQN agent on the 'SwarmEnv-v0' gym environment and save the
    resulting weights and model.

    Steps:
    1. Connect to NATS.
    2. Request this service's config from the config service on
       "initialize.service".
    3. Build a small dense Q-network.
    4. Train the agent for 500 steps.
    5. Save the weights to shared_storage['weights_location'] and the model
       to shared_storage['model_location'].

    The environment is given the NATS handler, through which it interfaces
    with the rest of the simulation. The NATS connection is closed when the
    method finishes, including on error.

    Args:
        nats_host (str, optional): NATS server host. Defaults to "nats".
        nats_port (str, optional): NATS server port. Defaults to "4222".
        nats_user (str, optional): NATS user. Defaults to None.
        nats_password (str, optional): NATS password. Defaults to None.
        api_host (str, optional): Currently unused by this method. Defaults to "127.0.0.1".
        api_port (int, optional): Currently unused by this method. Defaults to 8000.
        redis_host (str, optional): Currently unused by this method. Defaults to "redis".
        redis_port (int, optional): Currently unused by this method. Defaults to 6379.
        redis_password (str, optional): Currently unused by this method. Defaults to None.
    """
    # creating NATS client; nats_port used to be accepted but never forwarded
    nats = NatsHandler("default",
                       host=nats_host,
                       port=nats_port,
                       user=nats_user,
                       password=nats_password)
    nats.loop.set_exception_handler(None)
    nats.loop.run_until_complete(nats.connect())
    try:
        # getting config from config service
        message = nats.create_message(self.service_type,
                                      MessageSchemas.SERVICE_TYPE_MESSAGE)
        config_response = nats.loop.run_until_complete(
            nats.request_message("initialize.service",
                                 message,
                                 MessageSchemas.CONFIG_MESSAGE,
                                 timeout=3))
        validate_json(config_response.data["shared_storage"], self.schema)
        sender_id = config_response.data["sender_id"]
        shared_storage = config_response.data["shared_storage"]
        nats.sender_id = sender_id

        ENV_NAME = 'SwarmEnv-v0'
        # Get the environment and extract the number of actions.
        env = gym.make(ENV_NAME, nats=nats)
        nb_actions = env.action_space.n

        # Next, we build a very simple model.
        model = Sequential()
        model.add(Flatten(input_shape=(1, env.observation_space.n)))
        model.add(Dense(8))
        model.add(Activation('relu'))
        model.add(Dense(8))
        model.add(Activation('relu'))
        model.add(Dense(8))
        model.add(Activation('relu'))
        model.add(Dense(nb_actions))
        model.add(Activation('linear'))
        print(model.summary())

        # Finally, we configure and compile our agent. You can use every built-in
        # tensorflow.keras optimizer and even the metrics!
        memory = SequentialMemory(limit=1000, window_length=1)
        policy = EpsGreedyQPolicy()
        dqn = DQNAgent(model=model,
                       nb_actions=nb_actions,
                       memory=memory,
                       nb_steps_warmup=5,
                       target_model_update=1e-2,
                       policy=policy)
        # `learning_rate` replaces the deprecated `lr` keyword
        dqn.compile(Adam(learning_rate=1e-3), metrics=['mae'])
        dqn.fit(env, nb_steps=500, visualize=True, verbose=2)

        # Save the weights and model
        dqn.save_weights(
            f"{shared_storage['weights_location']}/dqn_{ENV_NAME}_weights.h5f",
            overwrite=True)
        model.save(f"{shared_storage['model_location']}/dqn_{ENV_NAME}_model")

        # NOTE(review): nb_episodes=0 runs no test episodes — confirm intended.
        dqn.test(env, nb_episodes=0, visualize=True)
    finally:
        # the original never closed the NATS connection
        nats.loop.run_until_complete(nats.disconnect())