Exemple #1
0
    def __init__(self, **kwargs) -> None:
        """
        Set up the gym dialogues model.

        The ``Model`` base is initialised first with every keyword argument;
        the dialogues base is then initialised with the agent address read
        from ``self.context``.

        :param kwargs: keyword arguments forwarded to ``Model.__init__``
        :return: None
        """
        Model.__init__(self, **kwargs)
        # self.context is only read once Model.__init__ has completed.
        own_address = self.context.agent_address
        BaseGymDialogues.__init__(self, own_address)
 def __init__(self, address: Address, gym_env: gym.Env):
     """
     Create a gym channel bound to an address and a gym environment.

     :param address: the address this channel operates under
     :param gym_env: the gym environment wrapped by the channel
     """
     self.address = address
     self.gym_env = gym_env
     # Loop and queue stay unset until the channel is connected.
     self._loop: Optional[AbstractEventLoop] = None
     self._queue: Optional[asyncio.Queue] = None
     pool_size = self.THREAD_POOL_SIZE
     self._threaded_pool: ThreadPoolExecutor = ThreadPoolExecutor(pool_size)
     self.logger: Union[logging.Logger, logging.LoggerAdapter] = logger
     dialogues_owner = str(PUBLIC_ID)
     self._dialogues = GymDialogues(dialogues_owner)
Exemple #3
0
 def setup(self):
     """Build a fresh GymConnection around a dummy GoalEnv."""
     self.env = gym.GoalEnv()
     connection_id = GymConnection.connection_id
     configuration = ConnectionConfig(connection_id=connection_id)
     self.agent_address = "my_address"
     identity = Identity("name", address=self.agent_address)
     self.gym_con = GymConnection(
         gym_env=self.env,
         identity=identity,
         configuration=configuration,
     )
     self.loop = asyncio.get_event_loop()
     # The connection id doubles as the gym-side address in these tests.
     self.gym_address = str(connection_id)
     self.dialogues = GymDialogues(self.gym_address)
Exemple #4
0
    def __init__(self, self_address: Address, **kwargs) -> None:
        """
        Set up gym dialogues in which this side always plays the AGENT role.

        :param self_address: the address of the owner of these dialogues
        :param kwargs: extra keyword arguments; accepted but NOT forwarded
            to ``BaseGymDialogues.__init__``
        :return: None
        """

        def _agent_role(  # pylint: disable=unused-argument
            message: Message, receiver_address: Address
        ) -> BaseDialogue.Role:
            """
            Report the role of this side for a new dialogue.

            The first message is ignored: the role is always AGENT.

            :param message: an incoming/outgoing first message
            :param receiver_address: the address of the receiving agent
            :return: GymDialogue.Role.AGENT
            """
            return GymDialogue.Role.AGENT

        BaseGymDialogues.__init__(
            self,
            self_address=self_address,
            role_from_first_message=_agent_role,
        )
Exemple #5
0
    def __init__(self, **kwargs: Any) -> None:
        """
        Set up connection-side gym dialogues that always play the ENVIRONMENT role.

        The dialogues are registered under ``str(PUBLIC_ID)`` as their own
        address.

        :param kwargs: keyword arguments forwarded to ``BaseGymDialogues.__init__``
        :return: None
        """

        def _environment_role(  # pylint: disable=unused-argument
            message: Message, receiver_address: Address
        ) -> BaseDialogue.Role:
            """
            Report the role of this side for a new dialogue.

            The first message is ignored: the role is always ENVIRONMENT.

            :param message: an incoming/outgoing first message
            :param receiver_address: the address of the receiving agent
            :return: GymDialogue.Role.ENVIRONMENT
            """
            # The gym connection maintains the dialogue on behalf of the environment
            return GymDialogue.Role.ENVIRONMENT

        connection_address = str(PUBLIC_ID)
        BaseGymDialogues.__init__(
            self,
            self_address=connection_address,
            role_from_first_message=_environment_role,
            **kwargs,
        )
Exemple #6
0
class TestGymConnection:
    """Test the packages/connection/gym/connection.py."""

    def setup_method(self):
        """Initialise the fixtures before each test."""
        # xunit-style ``setup_method`` rather than nose-style ``setup``:
        # pytest deprecated nose-style class hooks in 7.2 and removed them in 8.0.
        self.env = gym.GoalEnv()
        configuration = ConnectionConfig(
            connection_id=GymConnection.connection_id)
        self.agent_address = "my_address"
        identity = Identity("name", address=self.agent_address)
        self.gym_con = GymConnection(gym_env=self.env,
                                     identity=identity,
                                     configuration=configuration)
        self.loop = asyncio.get_event_loop()
        self.gym_address = str(GymConnection.connection_id)
        self.dialogues = GymDialogues(self.gym_address)

    def teardown_method(self):
        """Clean up after each test."""
        self.loop.run_until_complete(self.gym_con.disconnect())

    @pytest.mark.asyncio
    async def test_gym_connection_connect(self):
        """Test the connection None return value after connect()."""
        assert self.gym_con.channel._queue is None
        await self.gym_con.channel.connect()
        assert self.gym_con.channel._queue is not None

    @pytest.mark.asyncio
    async def test_decode_envelope_error(self):
        """Test the decoding error for the envelopes."""
        await self.gym_con.connect()
        envelope = Envelope(
            to=self.gym_address,
            sender=self.agent_address,
            protocol_id=UNKNOWN_PROTOCOL_PUBLIC_ID,
            message=b"hello",
        )

        with pytest.raises(ValueError):
            await self.gym_con.send(envelope)

    @pytest.mark.asyncio
    async def test_send_connection_error(self):
        """Test send connection error."""
        msg = GymMessage(
            performative=GymMessage.Performative.RESET,
            dialogue_reference=self.dialogues.
            new_self_initiated_dialogue_reference(),
        )
        msg.counterparty = self.gym_address
        sending_dialogue = self.dialogues.update(msg)
        assert sending_dialogue is not None
        envelope = Envelope(
            to=msg.counterparty,
            sender=msg.sender,
            protocol_id=msg.protocol_id,
            message=msg,
        )

        with pytest.raises(ConnectionError):
            await self.gym_con.send(envelope)

    @pytest.mark.asyncio
    async def test_send_act(self):
        """Test send act message."""
        sending_dialogue = await self.send_reset()
        last_message = sending_dialogue.last_message
        assert last_message is not None
        msg = GymMessage(
            performative=GymMessage.Performative.ACT,
            action=GymMessage.AnyObject("any_action"),
            step_id=1,
            dialogue_reference=sending_dialogue.dialogue_label.
            dialogue_reference,
            message_id=last_message.message_id + 1,
            target=last_message.message_id,
        )
        msg.counterparty = self.gym_address
        assert sending_dialogue.update(msg)
        envelope = Envelope(
            to=msg.counterparty,
            sender=msg.sender,
            protocol_id=msg.protocol_id,
            message=msg,
        )
        await self.gym_con.connect()

        observation = 1
        reward = 1.0
        done = True
        info = "some info"
        with patch.object(self.env,
                          "step",
                          return_value=(observation, reward, done,
                                        info)) as mock:
            await self.gym_con.send(envelope)
            mock.assert_called()

        response = await asyncio.wait_for(self.gym_con.receive(), timeout=3)
        response_msg_orig = cast(GymMessage, response.message)
        response_msg = copy.copy(response_msg_orig)
        response_msg.is_incoming = True
        response_msg.counterparty = response_msg_orig.sender
        response_dialogue = self.dialogues.update(response_msg)

        assert response_msg.performative == GymMessage.Performative.PERCEPT
        assert response_msg.step_id == msg.step_id
        assert response_msg.observation.any == observation
        assert response_msg.reward == reward
        assert response_msg.done == done
        assert response_msg.info.any == info
        assert sending_dialogue == response_dialogue

    @pytest.mark.asyncio
    async def test_send_reset(self):
        """Test send reset message."""
        _ = await self.send_reset()

    @pytest.mark.asyncio
    async def test_send_close(self):
        """Test send close message."""
        sending_dialogue = await self.send_reset()
        last_message = sending_dialogue.last_message
        assert last_message is not None
        msg = GymMessage(
            performative=GymMessage.Performative.CLOSE,
            dialogue_reference=sending_dialogue.dialogue_label.
            dialogue_reference,
            message_id=last_message.message_id + 1,
            target=last_message.message_id,
        )
        msg.counterparty = self.gym_address
        assert sending_dialogue.update(msg)
        envelope = Envelope(
            to=msg.counterparty,
            sender=msg.sender,
            protocol_id=msg.protocol_id,
            message=msg,
        )
        await self.gym_con.connect()

        with patch.object(self.env, "close") as mock:
            await self.gym_con.send(envelope)
            mock.assert_called()

    @pytest.mark.asyncio
    async def test_send_close_negative(self):
        """Test send close message with invalid reference and message id and target."""
        msg = GymMessage(
            performative=GymMessage.Performative.CLOSE,
            dialogue_reference=self.dialogues.
            new_self_initiated_dialogue_reference(),
        )
        msg.counterparty = self.gym_address
        dialogue = self.dialogues.update(msg)
        assert dialogue is None
        msg.sender = self.agent_address
        envelope = Envelope(
            to=msg.counterparty,
            sender=msg.sender,
            protocol_id=msg.protocol_id,
            message=msg,
        )
        await self.gym_con.connect()

        with patch.object(self.gym_con.channel.logger,
                          "warning") as mock_logger:
            await self.gym_con.send(envelope)
            mock_logger.assert_any_call(
                f"Could not create dialogue from message={msg}")

    async def send_reset(self) -> GymDialogue:
        """
        Send a RESET message and check the STATUS reply.

        :return: the dialogue the RESET was sent in
        """
        msg = GymMessage(
            performative=GymMessage.Performative.RESET,
            dialogue_reference=self.dialogues.
            new_self_initiated_dialogue_reference(),
        )
        msg.counterparty = self.gym_address
        sending_dialogue = self.dialogues.update(msg)
        assert sending_dialogue is not None
        envelope = Envelope(
            to=msg.counterparty,
            sender=msg.sender,
            protocol_id=msg.protocol_id,
            message=msg,
        )
        await self.gym_con.connect()

        with patch.object(self.env, "reset") as mock:
            await self.gym_con.send(envelope)
            mock.assert_called()

        response = await asyncio.wait_for(self.gym_con.receive(), timeout=3)
        response_msg_orig = cast(GymMessage, response.message)
        response_msg = copy.copy(response_msg_orig)
        response_msg.is_incoming = True
        response_msg.counterparty = response_msg_orig.sender
        response_dialogue = self.dialogues.update(response_msg)

        assert response_msg.performative == GymMessage.Performative.STATUS
        assert response_msg.content == {"reset": "success"}
        assert sending_dialogue == response_dialogue
        return sending_dialogue

    @pytest.mark.asyncio
    async def test_receive_connection_error(self):
        """Test that receive() raises ConnectionError when not connected."""
        with pytest.raises(ConnectionError):
            await self.gym_con.receive()

    def test_gym_env_load(self):
        """Load gym env from file."""
        curdir = os.getcwd()
        os.chdir(os.path.join(ROOT_DIR, "examples", "gym_ex"))
        try:
            gym_env_path = "gyms.env.BanditNArmedRandom"
            configuration = ConnectionConfig(
                connection_id=GymConnection.connection_id, env=gym_env_path)
            identity = Identity("name", address=self.agent_address)
            gym_con = GymConnection(gym_env=None,
                                    identity=identity,
                                    configuration=configuration)
            assert gym_con.channel.gym_env is not None
        finally:
            # Always restore the working directory; otherwise a failure here
            # would leave every subsequent test running in examples/gym_ex.
            os.chdir(curdir)
class GymChannel:
    """A wrapper of the gym environment."""

    THREAD_POOL_SIZE = 3

    def __init__(self, address: Address, gym_env: gym.Env):
        """
        Initialize a gym channel.

        :param address: the address of the channel
        :param gym_env: the gym environment that messages are forwarded to
        """
        self.address = address
        self.gym_env = gym_env
        # Both are set by connect(); None while the channel is disconnected.
        self._loop: Optional[AbstractEventLoop] = None
        self._queue: Optional[asyncio.Queue] = None
        # Gym calls (step/reset/close) are executed on this pool, off the loop.
        self._threaded_pool: ThreadPoolExecutor = ThreadPoolExecutor(
            self.THREAD_POOL_SIZE)
        self.logger: Union[logging.Logger, logging.LoggerAdapter] = logger
        self._dialogues = GymDialogues(str(PUBLIC_ID))

    def _get_message_and_dialogue(
            self,
            envelope: Envelope) -> Tuple[GymMessage, Optional[GymDialogue]]:
        """
        Get a message copy and dialogue related to this message.

        :param envelope: incoming envelope

        :return: Tuple[Message, Optional[Dialogue]]; the dialogue is None
            when the dialogues object rejects the message.
        """
        orig_message = cast(GymMessage, envelope.message)
        message = copy.copy(
            orig_message
        )  # TODO: fix; need to copy atm to avoid overwriting "is_incoming"
        message.is_incoming = True  # TODO: fix; should be done by framework
        message.counterparty = (orig_message.sender
                                )  # TODO: fix; should be done by framework
        dialogue = cast(Optional[GymDialogue], self._dialogues.update(message))
        return message, dialogue

    @property
    def queue(self) -> asyncio.Queue:
        """Check queue is set and return queue."""
        if self._queue is None:  # pragma: nocover
            raise ValueError("Channel is not connected")
        return self._queue

    async def connect(self) -> None:
        """
        Connect an address to the gym.

        Captures the current event loop and creates the internal queue that
        get() reads from. Calling it again while connected is a no-op.

        :return: None
        """
        if self._queue is not None:  # pragma: nocover
            return None
        self._loop = asyncio.get_event_loop()
        self._queue = asyncio.Queue()

    async def send(self, envelope: Envelope) -> None:
        """
        Process the envelopes to the gym.

        :param envelope: the envelope to process
        :return: None
        :raises ValueError: if the envelope is not for the gym protocol.
        """
        sender = envelope.sender
        self.logger.debug("Processing message from {}: {}".format(
            sender, envelope))
        if envelope.protocol_id != GymMessage.protocol_id:
            raise ValueError("This protocol is not valid for gym.")
        await self.handle_gym_message(envelope)

    async def _run_in_executor(self, fn, *args):
        """
        Run a blocking callable on the channel's thread pool and await it.

        :param fn: the callable to run
        :param args: positional arguments for ``fn``
        :return: whatever ``fn`` returns
        :raises ValueError: if the channel has not been connected yet.
        """
        if self._loop is None:
            raise ValueError("Channel is not connected")
        return await self._loop.run_in_executor(self._threaded_pool, fn, *args)

    async def handle_gym_message(self, envelope: Envelope) -> None:
        """
        Forward a message to gym.

        ACT triggers ``gym_env.step`` and replies with a PERCEPT; RESET triggers
        ``gym_env.reset`` and replies with a STATUS; CLOSE triggers
        ``gym_env.close`` without a reply. Any other performative is logged
        and dropped.

        :param envelope: the envelope
        :return: None
        """
        assert isinstance(envelope.message,
                          GymMessage), "Message not of type GymMessage"
        gym_message, dialogue = self._get_message_and_dialogue(envelope)

        if dialogue is None:
            self.logger.warning(
                "Could not create dialogue from message={}".format(
                    gym_message))
            return

        performative = gym_message.performative
        if performative == GymMessage.Performative.ACT:
            action = gym_message.action.any
            step_id = gym_message.step_id

            observation, reward, done, info = await self._run_in_executor(
                self.gym_env.step, action)
            msg = GymMessage(
                performative=GymMessage.Performative.PERCEPT,
                observation=GymMessage.AnyObject(observation),
                reward=reward,
                done=done,
                info=GymMessage.AnyObject(info),
                step_id=step_id,
                target=gym_message.message_id,
                message_id=gym_message.message_id + 1,
                dialogue_reference=dialogue.dialogue_label.dialogue_reference,
            )
        elif performative == GymMessage.Performative.RESET:
            await self._run_in_executor(self.gym_env.reset)
            msg = GymMessage(
                performative=GymMessage.Performative.STATUS,
                content={"reset": "success"},
                target=gym_message.message_id,
                message_id=gym_message.message_id + 1,
                dialogue_reference=dialogue.dialogue_label.dialogue_reference,
            )
        elif performative == GymMessage.Performative.CLOSE:
            await self._run_in_executor(self.gym_env.close)
            return
        else:
            # Previously fell through with `msg` unbound (UnboundLocalError).
            self.logger.warning(
                "Unexpected performative={} in message={}".format(
                    performative, gym_message))
            return

        msg.counterparty = gym_message.counterparty
        # update() is called outside the assert so the dialogue is still
        # updated when assertions are stripped (python -O).
        updated = dialogue.update(msg)
        assert updated, "Error during dialogue update."
        envelope = Envelope(
            to=msg.counterparty,
            sender=msg.sender,
            protocol_id=msg.protocol_id,
            message=msg,
        )
        await self._send(envelope)

    async def _send(self, envelope: Envelope) -> None:
        """Send a message.

        :param envelope: the envelope
        :return: None
        """
        await self.queue.put(envelope)

    async def disconnect(self) -> None:
        """
        Disconnect.

        Pushes a ``None`` sentinel onto the queue for any pending get(),
        drops the queue, and releases the thread pool's worker threads.

        :return: None
        """
        if self._queue is not None:
            await self._queue.put(None)
            self._queue = None
            # Release worker threads without blocking the event loop; install
            # a fresh pool (threads are only spawned on first use) so a later
            # connect() can still run gym calls.
            self._threaded_pool.shutdown(wait=False)
            self._threaded_pool = ThreadPoolExecutor(self.THREAD_POOL_SIZE)

    async def get(self) -> Optional[Envelope]:
        """
        Get incoming envelope.

        :return: the next envelope, or None once disconnect() has run.
        """
        return await self.queue.get()