def __init__(self, **kwargs) -> None:
    """
    Set up the gym dialogues model.

    All keyword arguments are handed to ``Model.__init__``; afterwards the
    agent address taken from ``self.context`` is passed positionally to
    ``BaseGymDialogues.__init__``.

    :return: None
    """
    Model.__init__(self, **kwargs)
    own_address = self.context.agent_address
    BaseGymDialogues.__init__(self, own_address)
def __init__(self, address: Address, gym_env: gym.Env):
    """
    Build a gym channel around the given environment.

    :param address: the address associated with this channel.
    :param gym_env: the gym environment the channel drives.
    """
    self.address = address
    self.gym_env = gym_env
    # No loop and no queue until the channel gets connected.
    self._loop: Optional[AbstractEventLoop] = None
    self._queue: Optional[asyncio.Queue] = None
    worker_count = self.THREAD_POOL_SIZE
    self._threaded_pool: ThreadPoolExecutor = ThreadPoolExecutor(worker_count)
    self.logger: Union[logging.Logger, logging.LoggerAdapter] = logger
    channel_id = str(PUBLIC_ID)
    self._dialogues = GymDialogues(channel_id)
def setup(self):
    """Create a fresh environment, gym connection, loop handle and dialogues."""
    self.env = gym.GoalEnv()
    conn_config = ConnectionConfig(connection_id=GymConnection.connection_id)
    self.agent_address = "my_address"
    agent_identity = Identity("name", address=self.agent_address)
    self.gym_con = GymConnection(
        gym_env=self.env,
        identity=agent_identity,
        configuration=conn_config,
    )
    self.loop = asyncio.get_event_loop()
    self.gym_address = str(GymConnection.connection_id)
    # The test-side dialogues are keyed on the connection's own id.
    self.dialogues = GymDialogues(self.gym_address)
def __init__(self, self_address: Address, **kwargs) -> None:
    """
    Initialize dialogues in which this side always plays the agent role.

    :param self_address: the address handed to ``BaseGymDialogues`` as
        ``self_address``.
    :return: None
    """

    def _always_agent(  # pylint: disable=unused-argument
        message: Message, receiver_address: Address
    ) -> BaseDialogue.Role:
        """
        Return the role for a dialogue started by ``message``.

        :param message: an incoming/outgoing first message
        :param receiver_address: the address of the receiving agent
        :return: ``GymDialogue.Role.AGENT``, regardless of the inputs
        """
        return GymDialogue.Role.AGENT

    # NOTE(review): **kwargs is accepted but not forwarded to BaseGymDialogues
    # (a sibling variant forwards it) — confirm that dropping it is intended.
    BaseGymDialogues.__init__(
        self,
        self_address=self_address,
        role_from_first_message=_always_agent,
    )
def __init__(self, **kwargs: Any) -> None:
    """
    Initialize dialogues in which this side always plays the environment role.

    The dialogues' own address is ``str(PUBLIC_ID)``; every keyword argument
    is forwarded to ``BaseGymDialogues.__init__``.

    :return: None
    """

    def _always_environment(  # pylint: disable=unused-argument
        message: Message, receiver_address: Address
    ) -> BaseDialogue.Role:
        """
        Return the role for a dialogue started by ``message``.

        :param message: an incoming/outgoing first message
        :param receiver_address: the address of the receiving agent
        :return: ``GymDialogue.Role.ENVIRONMENT``, regardless of the inputs
        """
        # The gym connection maintains the dialogue on behalf of the environment
        return GymDialogue.Role.ENVIRONMENT

    own_address = str(PUBLIC_ID)
    BaseGymDialogues.__init__(
        self,
        self_address=own_address,
        role_from_first_message=_always_environment,
        **kwargs,
    )
class TestGymConnection:
    """Test the packages/connection/gym/connection.py."""

    def setup(self):
        """Initialise the class."""
        self.env = gym.GoalEnv()
        configuration = ConnectionConfig(connection_id=GymConnection.connection_id)
        self.agent_address = "my_address"
        identity = Identity("name", address=self.agent_address)
        self.gym_con = GymConnection(
            gym_env=self.env, identity=identity, configuration=configuration
        )
        self.loop = asyncio.get_event_loop()
        self.gym_address = str(GymConnection.connection_id)
        self.dialogues = GymDialogues(self.gym_address)

    def teardown(self):
        """Clean up after tests."""
        self.loop.run_until_complete(self.gym_con.disconnect())

    @pytest.mark.asyncio
    async def test_gym_connection_connect(self):
        """Test the connection None return value after connect()."""
        assert self.gym_con.channel._queue is None
        await self.gym_con.channel.connect()
        assert self.gym_con.channel._queue is not None

    @pytest.mark.asyncio
    async def test_decode_envelope_error(self):
        """Test the decoding error for the envelopes."""
        await self.gym_con.connect()
        envelope = Envelope(
            to=self.gym_address,
            sender=self.agent_address,
            protocol_id=UNKNOWN_PROTOCOL_PUBLIC_ID,
            message=b"hello",
        )

        with pytest.raises(ValueError):
            await self.gym_con.send(envelope)

    @pytest.mark.asyncio
    async def test_send_connection_error(self):
        """Test send connection error."""
        msg = GymMessage(
            performative=GymMessage.Performative.RESET,
            dialogue_reference=self.dialogues.new_self_initiated_dialogue_reference(),
        )
        msg.counterparty = self.gym_address
        sending_dialogue = self.dialogues.update(msg)
        assert sending_dialogue is not None
        envelope = Envelope(
            to=msg.counterparty,
            sender=msg.sender,
            protocol_id=msg.protocol_id,
            message=msg,
        )

        # The connection was never connected, so sending must fail.
        with pytest.raises(ConnectionError):
            await self.gym_con.send(envelope)

    @pytest.mark.asyncio
    async def test_send_act(self):
        """Test send act message."""
        sending_dialogue = await self.send_reset()
        last_message = sending_dialogue.last_message
        assert last_message is not None
        msg = GymMessage(
            performative=GymMessage.Performative.ACT,
            action=GymMessage.AnyObject("any_action"),
            step_id=1,
            dialogue_reference=sending_dialogue.dialogue_label.dialogue_reference,
            message_id=last_message.message_id + 1,
            target=last_message.message_id,
        )
        msg.counterparty = self.gym_address
        assert sending_dialogue.update(msg)
        envelope = Envelope(
            to=msg.counterparty,
            sender=msg.sender,
            protocol_id=msg.protocol_id,
            message=msg,
        )
        await self.gym_con.connect()

        observation = 1
        reward = 1.0
        done = True
        info = "some info"
        with patch.object(
            self.env, "step", return_value=(observation, reward, done, info)
        ) as mock:
            await self.gym_con.send(envelope)
            mock.assert_called()

        response = await asyncio.wait_for(self.gym_con.receive(), timeout=3)
        response_msg_orig = cast(GymMessage, response.message)
        # Copy and mark as incoming before updating the test-side dialogues.
        response_msg = copy.copy(response_msg_orig)
        response_msg.is_incoming = True
        response_msg.counterparty = response_msg_orig.sender
        response_dialogue = self.dialogues.update(response_msg)

        assert response_msg.performative == GymMessage.Performative.PERCEPT
        assert response_msg.step_id == msg.step_id
        assert response_msg.observation.any == observation
        assert response_msg.reward == reward
        assert response_msg.done == done
        assert response_msg.info.any == info
        assert sending_dialogue == response_dialogue

    @pytest.mark.asyncio
    async def test_send_reset(self):
        """Test send reset message."""
        _ = await self.send_reset()

    @pytest.mark.asyncio
    async def test_send_close(self):
        """Test send close message."""
        sending_dialogue = await self.send_reset()
        last_message = sending_dialogue.last_message
        assert last_message is not None
        msg = GymMessage(
            performative=GymMessage.Performative.CLOSE,
            dialogue_reference=sending_dialogue.dialogue_label.dialogue_reference,
            message_id=last_message.message_id + 1,
            target=last_message.message_id,
        )
        msg.counterparty = self.gym_address
        assert sending_dialogue.update(msg)
        envelope = Envelope(
            to=msg.counterparty,
            sender=msg.sender,
            protocol_id=msg.protocol_id,
            message=msg,
        )
        await self.gym_con.connect()

        with patch.object(self.env, "close") as mock:
            await self.gym_con.send(envelope)
            mock.assert_called()

    @pytest.mark.asyncio
    async def test_send_close_negative(self):
        """Test send close message with invalid reference and message id and target."""
        msg = GymMessage(
            performative=GymMessage.Performative.CLOSE,
            dialogue_reference=self.dialogues.new_self_initiated_dialogue_reference(),
        )
        msg.counterparty = self.gym_address
        # A CLOSE cannot open a dialogue, so the local update is rejected.
        dialogue = self.dialogues.update(msg)
        assert dialogue is None
        msg.sender = self.agent_address
        envelope = Envelope(
            to=msg.counterparty,
            sender=msg.sender,
            protocol_id=msg.protocol_id,
            message=msg,
        )
        await self.gym_con.connect()

        with patch.object(self.gym_con.channel.logger, "warning") as mock_logger:
            await self.gym_con.send(envelope)
            mock_logger.assert_any_call(
                f"Could not create dialogue from message={msg}"
            )

    async def send_reset(self) -> GymDialogue:
        """
        Send a reset and check the STATUS reply.

        :return: the dialogue opened by the reset message.
        """
        msg = GymMessage(
            performative=GymMessage.Performative.RESET,
            dialogue_reference=self.dialogues.new_self_initiated_dialogue_reference(),
        )
        msg.counterparty = self.gym_address
        sending_dialogue = self.dialogues.update(msg)
        assert sending_dialogue is not None
        envelope = Envelope(
            to=msg.counterparty,
            sender=msg.sender,
            protocol_id=msg.protocol_id,
            message=msg,
        )
        await self.gym_con.connect()

        with patch.object(self.env, "reset") as mock:
            await self.gym_con.send(envelope)
            mock.assert_called()

        response = await asyncio.wait_for(self.gym_con.receive(), timeout=3)
        response_msg_orig = cast(GymMessage, response.message)
        response_msg = copy.copy(response_msg_orig)
        response_msg.is_incoming = True
        response_msg.counterparty = response_msg_orig.sender
        response_dialogue = self.dialogues.update(response_msg)

        assert response_msg.performative == GymMessage.Performative.STATUS
        assert response_msg.content == {"reset": "success"}
        assert sending_dialogue == response_dialogue
        return sending_dialogue

    @pytest.mark.asyncio
    async def test_receive_connection_error(self):
        """Test receive connection error and Cancel Error."""
        with pytest.raises(ConnectionError):
            await self.gym_con.receive()

    def test_gym_env_load(self):
        """Load gym env from file."""
        curdir = os.getcwd()
        # Presumably the dotted env path is only importable from the example
        # directory. Restore the cwd in `finally` so a failure here cannot
        # leak a changed working directory into later tests.
        os.chdir(os.path.join(ROOT_DIR, "examples", "gym_ex"))
        try:
            gym_env_path = "gyms.env.BanditNArmedRandom"
            configuration = ConnectionConfig(
                connection_id=GymConnection.connection_id, env=gym_env_path
            )
            identity = Identity("name", address=self.agent_address)
            gym_con = GymConnection(
                gym_env=None, identity=identity, configuration=configuration
            )
            assert gym_con.channel.gym_env is not None
        finally:
            os.chdir(curdir)
class GymChannel:
    """A wrapper of the gym environment."""

    THREAD_POOL_SIZE = 3

    def __init__(self, address: Address, gym_env: gym.Env):
        """
        Initialize a gym channel.

        :param address: the address of the channel.
        :param gym_env: the gym environment wrapped by the channel.
        """
        self.address = address
        self.gym_env = gym_env
        # Both are set by connect(); None while the channel is not connected.
        self._loop: Optional[AbstractEventLoop] = None
        self._queue: Optional[asyncio.Queue] = None
        # env.step/reset/close are executed in this pool rather than on the
        # event loop thread (see _run_in_executor).
        self._threaded_pool: ThreadPoolExecutor = ThreadPoolExecutor(
            self.THREAD_POOL_SIZE
        )
        self.logger: Union[logging.Logger, logging.LoggerAdapter] = logger
        self._dialogues = GymDialogues(str(PUBLIC_ID))

    def _get_message_and_dialogue(
        self, envelope: Envelope
    ) -> Tuple[GymMessage, Optional[GymDialogue]]:
        """
        Get a message copy and the dialogue related to this message.

        :param envelope: incoming envelope
        :return: a tuple of the copied message and its dialogue; the dialogue
            may be None if the dialogues rejected the message.
        """
        orig_message = cast(GymMessage, envelope.message)
        # TODO: fix; need to copy atm to avoid overwriting "is_incoming"
        message = copy.copy(orig_message)
        message.is_incoming = True  # TODO: fix; should be done by framework
        # TODO: fix; should be done by framework
        message.counterparty = orig_message.sender
        dialogue = cast(GymDialogue, self._dialogues.update(message))
        return message, dialogue

    @property
    def queue(self) -> asyncio.Queue:
        """
        Return the envelope queue.

        :raises ValueError: if the channel is not connected.
        """
        if self._queue is None:  # pragma: nocover
            raise ValueError("Channel is not connected")
        return self._queue

    async def connect(self) -> None:
        """
        Connect the channel.

        Stores the current event loop and creates the envelope queue. If the
        queue already exists this is a no-op.

        :return: None
        """
        if self._queue:  # pragma: nocover
            return None
        self._loop = asyncio.get_event_loop()
        self._queue = asyncio.Queue()

    async def send(self, envelope: Envelope) -> None:
        """
        Process an envelope addressed to the gym.

        :param envelope: the envelope to process
        :return: None
        :raises ValueError: if the envelope does not carry the gym protocol.
        """
        sender = envelope.sender
        self.logger.debug(
            "Processing message from {}: {}".format(sender, envelope)
        )
        if envelope.protocol_id != GymMessage.protocol_id:
            raise ValueError("This protocol is not valid for gym.")
        await self.handle_gym_message(envelope)

    async def _run_in_executor(self, fn, *args):
        """
        Run ``fn(*args)`` in the channel's thread pool, off the event loop thread.

        :param fn: the callable to run
        :param args: positional arguments passed to ``fn``
        :return: the value returned by ``fn``
        :raises ValueError: if the channel is not connected.
        """
        if self._loop is None:
            # Same error as the `queue` property, instead of an AttributeError
            # on None.
            raise ValueError("Channel is not connected")
        return await self._loop.run_in_executor(self._threaded_pool, fn, *args)

    async def handle_gym_message(self, envelope: Envelope) -> None:
        """
        Forward a message to gym.

        ACT calls ``gym_env.step`` and replies with a PERCEPT. RESET calls
        ``gym_env.reset`` and replies with a STATUS. CLOSE calls
        ``gym_env.close`` and sends no reply. Any other performative is logged
        as a warning and ignored.

        :param envelope: the envelope
        :return: None
        """
        assert isinstance(
            envelope.message, GymMessage
        ), "Message not of type GymMessage"
        gym_message, dialogue = self._get_message_and_dialogue(envelope)

        if dialogue is None:
            self.logger.warning(
                "Could not create dialogue from message={}".format(gym_message)
            )
            return

        if gym_message.performative == GymMessage.Performative.ACT:
            action = gym_message.action.any
            step_id = gym_message.step_id
            observation, reward, done, info = await self._run_in_executor(
                self.gym_env.step, action
            )
            msg = GymMessage(
                performative=GymMessage.Performative.PERCEPT,
                observation=GymMessage.AnyObject(observation),
                reward=reward,
                done=done,
                info=GymMessage.AnyObject(info),
                step_id=step_id,
                target=gym_message.message_id,
                message_id=gym_message.message_id + 1,
                dialogue_reference=dialogue.dialogue_label.dialogue_reference,
            )
        elif gym_message.performative == GymMessage.Performative.RESET:
            await self._run_in_executor(self.gym_env.reset)
            msg = GymMessage(
                performative=GymMessage.Performative.STATUS,
                content={"reset": "success"},
                target=gym_message.message_id,
                message_id=gym_message.message_id + 1,
                dialogue_reference=dialogue.dialogue_label.dialogue_reference,
            )
        elif gym_message.performative == GymMessage.Performative.CLOSE:
            await self._run_in_executor(self.gym_env.close)
            return
        else:
            # Without this branch `msg` would be unbound below and raise
            # UnboundLocalError.
            self.logger.warning(
                "Cannot handle gym message with performative={}".format(
                    gym_message.performative
                )
            )
            return

        msg.counterparty = gym_message.counterparty
        # Keep the update out of the assert: asserts are stripped under
        # `python -O`, which would silently skip recording the reply.
        is_updated = dialogue.update(msg)
        assert is_updated, "Error during dialogue update."
        envelope = Envelope(
            to=msg.counterparty,
            sender=msg.sender,
            protocol_id=msg.protocol_id,
            message=msg,
        )
        await self._send(envelope)

    async def _send(self, envelope: Envelope) -> None:
        """
        Put an outgoing envelope on the queue.

        :param envelope: the envelope
        :return: None
        """
        await self.queue.put(envelope)

    async def disconnect(self) -> None:
        """
        Disconnect.

        Puts a ``None`` sentinel on the queue, then drops the queue.

        :return: None
        """
        if self._queue is not None:
            await self._queue.put(None)
            self._queue = None

    async def get(self) -> Optional[Envelope]:
        """
        Get the next envelope from the queue.

        :return: the envelope, or the ``None`` sentinel put by ``disconnect``.
        """
        return await self.queue.get()