Ejemplo n.º 1
0
 def set_shared_storage(self, shared_storage):
     """
     Validates the shared storage and writes it to the redis server.

     The storage is checked against the schema of the service, serialized
     to a JSON string and stored under the service type key.

     Args:
         shared_storage: JSON-serializable shared storage of the service.

     Returns:
         bool: True after the serialized storage was passed to redis.
     """
     validate_json(shared_storage, self.schema)
     self.redis_client.set(self.service_type, json.dumps(shared_storage))
     return True
Ejemplo n.º 2
0
    def test_validate_message(self):
        """
        Checks validate_json against a nested message schema: a conforming
        message is accepted, while a message with a misspelled key and a
        message with a wrongly typed field are rejected.
        """
        schema = {
            "type": "object",
            "additionalProperties": False,
            "required": ["sender_ID", "time_sent", "data"],
            "properties": {
                "sender_ID": {"type": "string"},
                "time_sent": {"type": "string"},
                "data": {
                    "type": "object",
                    "additionalProperties": False,
                    "required": ["testData"],
                    "properties": {"testData": {"type": "string"}},
                },
            },
        }

        # a message matching the schema validates successfully
        valid_message = {
            "sender_ID": "abc",
            "time_sent": "ajfjejfeni49o4904f",
            "data": {"testData": "somedata"},
        }
        self.assertTrue(validate_json(valid_message, schema))

        rejected_messages = (
            # "time_sents" is not an allowed key and "time_sent" is missing
            {
                "sender_ID": "abc",
                "time_sents": "ajfjejfeni49o4904f",
                "data": {"testData": "somedata"},
            },
            # "sender_ID" has to be a string, not an integer
            {
                "sender_ID": 2,
                "time_sent": "ajfjejfeni49o4904f",
                "data": {"testData": "somedata"},
            },
        )
        for message in rejected_messages:
            with self.assertRaises(ValidationError):
                validate_json(message, schema)
Ejemplo n.º 3
0
 def get_shared_storage(self):
     """
     Reads the shared storage of the service from the redis server.

     The value stored under the service type key is parsed as JSON and
     validated against the schema of the service before it is returned.

     Returns:
         The decoded shared storage.
     """
     raw_storage = self.redis_client.get(self.service_type)
     decoded_storage = json.loads(raw_storage)
     validate_json(decoded_storage, self.schema)
     return decoded_storage
Ejemplo n.º 4
0
                async def callback_wrapper(msg):
                    # Runs the request callback on a copy of the shared storage and
                    # sends its response back on the reply channel of the message.

                    # any exception raised while handling the request is logged, not propagated
                    try:
                        # msg keeps referring to the raw message, so its reply channel stays available
                        decoded_message = Message.decode_raw(
                            msg.data, message_schema)

                        # the callback only ever sees a copy of the shared storage
                        storage_copy = self.shared_storage.copy()

                        reply_payload = await callback_function(
                            decoded_message, self.nats_client, storage_copy,
                            self._logger)

                        # adopt the copy only if it still conforms to the schema
                        storage_is_valid = validate_json(storage_copy,
                                                         self._schema)
                        if not storage_is_valid:
                            raise ValueError(
                                "Invalid change in shared storage")
                        self.shared_storage = storage_copy

                        # keep redis in sync with the current shared storage
                        self.redis_client.set_shared_storage(
                            self.shared_storage)

                        # answer the request via NATS
                        await self.nats_client.send_message(
                            msg.reply, reply_payload)
                    except Exception:
                        await self._logger.error(traceback.format_exc())
Ejemplo n.º 5
0
                async def callback_wrapper():
                    # Repeatedly executes the callback on a copy of the shared storage,
                    # sleeping `timeout` seconds between executions.

                    # any exception ends the loop and is logged instead of propagated
                    try:
                        while True:

                            # temporarily copy shared storage, so callback cannot perform invalid changes
                            shared_storage = self.shared_storage.copy()

                            # execute callback
                            await callback_function(self.nats_client,
                                                    shared_storage,
                                                    self._logger)

                            # Validate the copy the callback worked on (not the old storage)
                            # and reject it when invalid; the schema itself is never touched.
                            if not validate_json(shared_storage,
                                                 self._schema):
                                raise ValueError(
                                    "Invalid change in shared storage")
                            self.shared_storage = shared_storage

                            # buffer the current shared storage in redis
                            self.redis_client.set_shared_storage(
                                self.shared_storage)

                            # timeout until next loop execution
                            await asyncio.sleep(timeout)
                    except Exception:
                        await self._logger.error(traceback.format_exc())
                async def callback_wrapper(msg):
                    # Runs the data callback on a copy of the shared storage and keeps
                    # the copy only if it is still valid afterwards.

                    # any exception raised while handling the message is logged, not propagated
                    try:
                        decoded_message = Message.decode_raw(
                            msg.data, message_schema)

                        # the callback only ever sees a copy of the shared storage
                        storage_copy = self.shared_storage.copy()

                        # callbacks declaring five parameters additionally receive the kubernetes client
                        wants_kubernetes_client = len(
                            signature(callback_function).parameters) == 5
                        callback_args = [
                            decoded_message, self.nats_client, storage_copy,
                            self._logger
                        ]
                        if wants_kubernetes_client:
                            callback_args.append(self.kubernetes_client)
                        await callback_function(*callback_args)

                        # adopt the copy only if it still conforms to the schema
                        storage_is_valid = validate_json(storage_copy,
                                                         self._schema)
                        if not storage_is_valid:
                            raise ValueError(
                                "Invalid change in shared storage")
                        self.shared_storage = storage_copy

                        # keep redis in sync with the current shared storage
                        self.redis_client.set_shared_storage(
                            self.shared_storage)
                    except Exception:
                        await self._logger.error(traceback.format_exc())
Ejemplo n.º 7
0
    def decode_json(cls, json_message: dict, schema: dict):
        """
        Builds a message object from a dict after validating it against a schema.

        Missing entries are filled in on the given dict before validation:
        "origin_ID" defaults to the value of "sender_ID", and "message_type"
        defaults to the "name" entry of the schema, or "unknown" if the schema
        has no such entry.

        Args:
            json_message (dict): Message in dict form. It is modified in place
                when defaults are filled in.
            schema (dict): Schema to validate the message format against.

        Returns:
            New instance of cls populated from the dict, or None if
            validate_json returns a falsy result.
        """

        if "origin_ID" not in json_message:
            json_message["origin_ID"] = json_message["sender_ID"]

        if "message_type" not in json_message:
            json_message["message_type"] = (schema["name"]
                                            if "name" in schema else "unknown")

        is_valid = validate_json(json_message, schema)
        if is_valid:
            return cls(schema,
                       sender_ID=json_message["sender_ID"],
                       origin_ID=json_message["origin_ID"],
                       message_type=json_message["message_type"],
                       time_sent=json_message["time_sent"],
                       data=json_message["data"])
        return None
Ejemplo n.º 8
0
    async def _load_config(self):
        """
        Override _load_config to get the configuration from a cluster service that has a callback registered on channel "initialize.service"
        for simulation.

        If requesting, validating or storing that configuration fails, the
        parent implementation of _load_config is used as a fallback.

        Raises:
            ValueError: If the fallback to the parent implementation fails too.
        """

        try:
            # requesting a config from the config service
            message = self.nats_client.create_message(
                self.service_type, MessageSchemas.SERVICE_TYPE_MESSAGE)
            print(
                f"Requesting config from config service for node {self.service_type}"
            )
            config_response = await self.nats_client.request_message(
                "initialize.service",
                message,
                MessageSchemas.CONFIG_MESSAGE,
                timeout=3)
            print(f"Got config from config service: {config_response}")
            print("Validating ...")

            # validate the shared storage section of the config
            validate_json(config_response.data["shared_storage"], self._schema)
            self.sender_id = config_response.data["sender_id"]
            self.shared_storage = config_response.data["shared_storage"]

            # write the shared storage and sender ID to Redis
            self.redis_client.set_sender_id(self.sender_id)
            self.redis_client.set_shared_storage(self.shared_storage)
            print(
                f"Successfully initialized {self.sender_id} {self.service_type} from config service"
            )
        except Exception as config_error:
            # Catch Exception rather than everything, so KeyboardInterrupt, SystemExit
            # and (on Python 3.8+) asyncio.CancelledError are not swallowed here.
            print(
                f"Could not initialize from config service ({config_error!r}), falling back to default configuration loading"
            )
            try:
                await super()._load_config()
            except Exception as e:
                raise ValueError(f"Failed to load configuration: {e}") from e
    async def _load_config(self):
        """
        Loads the sender_id and the initial shared_storage.

        If a config path is set, both are read from that JSON file, the shared
        storage is validated against the schema and both values are written to
        redis. Without a config path, both values are fetched from redis.

        Raises:
            ValueError: If no config path is set and initializing from redis fails.
        """
        if self.config_path is None:
            # no config file given, so redis is the only source
            try:
                self.sender_id = self.redis_client.get_sender_id()
                if not self.sender_id:
                    raise ValueError("Could not get sender id from redis")
                self.shared_storage = self.redis_client.get_shared_storage()
                print(
                    f"Successfully initialized {self.sender_id} {self.service_type} from redis"
                )
            except Exception as e:
                raise ValueError(
                    f"Failed to initialize from redis. Aborting. Error: {e}")
            return

        # a path to a config file is given, so initialize from there
        with open(self.config_path, "r") as config_file:
            config = json.load(config_file)

        # own sender_id comes straight from the config
        self.sender_id = config["sender_id"]

        # the shared_storage section has to match the schema before it is used
        initial_storage = config["shared_storage"]
        validate_json(initial_storage, self._schema)
        self.shared_storage = initial_storage

        # mirror the shared storage and sender ID to redis
        self.redis_client.set_shared_storage(self.shared_storage)
        self.redis_client.set_sender_id(self.sender_id)
        print(
            f"Successfully initialized {self.sender_id} {self.service_type} from file"
        )
Ejemplo n.º 10
0
    def encode_json(self) -> dict:
        """
        Builds the dict form of the message and validates it against the message schema.

        Returns:
            dict: Dict representation of the message, or None if validate_json
            returns a falsy result.
        """

        json_message = dict(
            sender_ID=self.sender_id,
            origin_ID=self.origin_id,
            message_type=self.message_type,
            time_sent=self.time_sent,
            data=self.data,
        )
        return json_message if validate_json(json_message,
                                             self.schema) else None
Ejemplo n.º 11
0
        async def run():
            # Startup coroutine: connects to NATS, creates the logger, loads the
            # sender id and initial shared storage (from a config file, else the
            # config service, else redis), registers the callbacks and runs the
            # startup callback if one is set.

            # set the exception handler to None. This makes exception actually stop code execution instead of going unnoticed
            loop = asyncio.get_running_loop()
            loop.set_exception_handler(None)

            # connect to the NATS server
            self.nats_client = NatsHandler("default",
                                           host=self.nats_host,
                                           port=self.nats_port,
                                           user=self.nats_user,
                                           password=self.nats_password,
                                           api_host=self.api_host,
                                           api_port=self.api_port,
                                           loop=loop)
            await self.nats_client.connect()

            # creating logger
            self._logger = NatsLoggerFactory.get_logger(
                self.nats_client, self.service_type)

            # retrieving initial shared_storage
            if self.config_path is not None:

                # if a path to a config file is given, initializes from there
                with open(self.config_path, "r") as f:
                    config = json.load(f)

                # get own sender_id from config
                self.sender_id = config["sender_id"]

                # validate the shared_storage section of the config
                validate_json(config["shared_storage"], self._schema)
                self.shared_storage = config["shared_storage"]

                # write the shared storage and sender ID to Redis
                self.redis_client.set_shared_storage(self.shared_storage)
                self.redis_client.set_sender_id(self.sender_id)
                print(
                    f"Successfully initialized {self.sender_id} {self.service_type} from file"
                )
            else:
                try:
                    # requesting a config from the config service
                    message = self.nats_client.create_message(
                        self.service_type, MessageSchemas.SERVICE_TYPE_MESSAGE)
                    print(
                        f"Requesting config from config service for node {self.service_type}"
                    )
                    config_response = await self.nats_client.request_message(
                        "initialize.service",
                        message,
                        MessageSchemas.CONFIG_MESSAGE,
                        timeout=3)
                    print(f"Got config from config service: {config_response}")
                    print("Validating ...")

                    # validate the shared storage section of the config
                    validate_json(config_response.data["shared_storage"],
                                  self._schema)
                    self.sender_id = config_response.data["sender_id"]
                    self.shared_storage = config_response.data[
                        "shared_storage"]

                    # write the shared storage and sender ID to Redis
                    self.redis_client.set_sender_id(self.sender_id)
                    self.redis_client.set_shared_storage(self.shared_storage)
                    print(
                        f"Successfully initialized {self.sender_id} {self.service_type} from config service"
                    )
                except Exception as config_error:
                    # Catch Exception rather than everything, so KeyboardInterrupt, SystemExit
                    # and (on Python 3.8+) asyncio.CancelledError are not swallowed here.
                    print(
                        f"Could not initialize from config service ({config_error!r}), falling back to redis"
                    )
                    try:
                        # try initializing from redis
                        self.sender_id = self.redis_client.get_sender_id()
                        if not self.sender_id:
                            raise ValueError(
                                "Could not get sender id from redis")
                        self.shared_storage = self.redis_client.get_shared_storage(
                        )
                        print(
                            f"Successfully initialized {self.sender_id} {self.service_type} from redis"
                        )
                    except Exception as e:
                        raise ValueError(
                            f"Failed to initialize from redis. Aborting. Error: {e}"
                        ) from e

            # setting nats sender id
            self.nats_client.sender_id = self.sender_id

            # registering callbacks
            await self._register_callbacks()

            # execute startup callback
            if self._startup_callback:
                await self._startup_callback(self.nats_client,
                                             self.shared_storage, self._logger)
Ejemplo n.º 12
0
    def run(self,
            nats_host="nats",
            nats_port="4222",
            nats_user=None,
            nats_password=None,
            api_host="127.0.0.1",
            api_port=8000,
            redis_host="redis",
            redis_port=6379,
            redis_password=None):
        """
        Trains an RL model to create weights for a unique simulation. Connects a NATS client, requests the configuration from the
        config service, creates a model and a DQN agent and fits the agent on the OpenAI Gym environment, which interfaces with the
        rest of the simulation through the rl service. Saves the trained weights and model to be used in predict mode.

        Args:
            nats_host (str, optional): NATS server host. Defaults to "nats".
            nats_port (str, optional): NATS server port. Defaults to "4222".
            nats_user (str, optional): NATS user. Defaults to None.
            nats_password (str, optional): NATS password. Defaults to None.
            api_host (str, optional): Host to run the REST API on. Defaults to "127.0.0.1". Not used by this method.
            api_port (int, optional): Port to run the REST API on. Defaults to 8000. Not used by this method.
            redis_host (str, optional): Host where Redis runs. Defaults to "redis". Not used by this method.
            redis_port (int, optional): Port where Redis runs. Defaults to 6379. Not used by this method.
            redis_password (str, optional): Password to access Redis. Defaults to None. Not used by this method.
        """
        # creating NATS client; the port is passed on so a non-default nats_port takes effect
        nats = NatsHandler("default",
                           host=nats_host,
                           port=nats_port,
                           user=nats_user,
                           password=nats_password)
        nats.loop.set_exception_handler(None)
        nats.loop.run_until_complete(nats.connect())

        # getting config from config service
        message = nats.create_message(self.service_type,
                                      MessageSchemas.SERVICE_TYPE_MESSAGE)
        config_response = nats.loop.run_until_complete(
            nats.request_message("initialize.service",
                                 message,
                                 MessageSchemas.CONFIG_MESSAGE,
                                 timeout=3))

        # validate the shared storage section of the config before using it
        validate_json(config_response.data["shared_storage"], self.schema)
        sender_id = config_response.data["sender_id"]
        shared_storage = config_response.data["shared_storage"]

        nats.sender_id = sender_id

        ENV_NAME = 'SwarmEnv-v0'
        # Get the environment and extract the number of actions.
        env = gym.make(ENV_NAME, nats=nats)
        #np.random.seed(123)
        #env.seed(123)
        nb_actions = env.action_space.n

        # Next, we build a very simple model.
        model = Sequential()
        model.add(Flatten(input_shape=(1, env.observation_space.n)))
        model.add(Dense(8))
        model.add(Activation('relu'))
        model.add(Dense(8))
        model.add(Activation('relu'))
        model.add(Dense(8))
        model.add(Activation('relu'))
        model.add(Dense(nb_actions))
        model.add(Activation('linear'))
        print(model.summary())
        # Finally, we configure and compile our agent. You can use every built-in tensorflow.keras optimizer and
        # even the metrics!
        memory = SequentialMemory(limit=1000, window_length=1)
        policy = EpsGreedyQPolicy()
        dqn = DQNAgent(model=model,
                       nb_actions=nb_actions,
                       memory=memory,
                       nb_steps_warmup=5,
                       target_model_update=1e-2,
                       policy=policy)
        dqn.compile(Adam(lr=1e-3), metrics=['mae'])

        dqn.fit(env, nb_steps=500, visualize=True, verbose=2)

        # Save the weights and model to the locations given in the shared storage
        dqn.save_weights(
            f"{shared_storage['weights_location']}/dqn_{ENV_NAME}_weights.h5f",
            overwrite=True)
        model.save(f"{shared_storage['model_location']}/dqn_{ENV_NAME}_model")
        dqn.test(env, nb_episodes=0, visualize=True)