Exemple #1
0
 def test_cannot_join_when_no_world_exists(self):
     """Joining must fail once the fixture's world has been destroyed.

     The world is re-created afterwards, even if the assertion fails, so that
     code which later destroys `self._world_name` still finds a live world.
     """
     self._connection.send(
         dm_env_rpc_pb2.DestroyWorldRequest(world_name=self._world_name))
     try:
         with self.assertRaises(error.DmEnvRpcError):
             self._connection.send(
                 dm_env_rpc_pb2.JoinWorldRequest(world_name=self._world_name))
     finally:
         # Restore the fixture state and record the name of the world that
         # now exists, so later cleanup targets the right world.
         self._world_name = self._connection.send(
             dm_env_rpc_pb2.CreateWorldRequest()).world_name
Exemple #2
0
 def close(self):
     """Leaves and destroys the world named by `self.world_name`.

     The parent class's `close()` is always invoked afterwards, even when one
     of the requests raises.
     """
     try:
         send = self.connection.send
         send(dm_env_rpc_pb2.LeaveWorldRequest())
         send(dm_env_rpc_pb2.DestroyWorldRequest(world_name=self.world_name))
     finally:
         super().close()
Exemple #3
0
 def test_cannot_destroy_world_when_still_joined(self):
     """Destroying a world must be rejected while this connection is joined."""
     join_request = dm_env_rpc_pb2.JoinWorldRequest(
         world_name=self._world_name)
     self._connection.send(join_request)

     destroy_request = dm_env_rpc_pb2.DestroyWorldRequest(
         world_name=self._world_name)
     with self.assertRaises(error.DmEnvRpcError):
         self._connection.send(destroy_request)
Exemple #4
0
# Return type of `create_and_join_world`. Defined once at module level rather
# than re-created as a fresh class on every call.
_DmEnvAndWorldName = collections.namedtuple('DmEnvAndWorldName',
                                            ['env', 'world_name'])


def create_and_join_world(
    connection: dm_env_rpc_connection.Connection,
    create_world_settings: Mapping[str, Any],
    join_world_settings: Mapping[str, Any],
    requested_observations: Optional[Iterable[str]] = None,
    extensions: Optional[Mapping[str, Any]] = immutabledict.immutabledict()
) -> Tuple[DmEnvAdaptor, str]:
    """Helper function to create and join a world with the provided settings.

    Args:
      connection: An instance of Connection already connected to a dm_env_rpc
        server.
      create_world_settings: Settings used to create the world. Values must be
        packable into a Tensor proto or already packed.
      join_world_settings: Settings used to join the world. Values must be
        packable into a Tensor message.
      requested_observations: Optional set of requested observations.
      extensions: Optional mapping of extension instances to DmEnvAdaptor
        attributes.

    Returns:
      Tuple of DmEnvAdaptor and the created world name, as a namedtuple with
      fields `env` and `world_name`.

    Raises:
      error.DmEnvRpcError, ValueError: Re-raised from `join_world` after a
        best-effort attempt to destroy the world that was just created.
    """
    world_name = create_world(connection, create_world_settings)
    try:
        env = join_world(connection, world_name, join_world_settings,
                         requested_observations, extensions)
    except (error.DmEnvRpcError, ValueError):
        # Don't leak the world that was just created. Cleanup is best effort.
        # If the destroy request itself fails, the original join error is
        # still re-raised below rather than being replaced by this secondary
        # failure.
        try:
            connection.send(
                dm_env_rpc_pb2.DestroyWorldRequest(world_name=world_name))
        except error.DmEnvRpcError:
            pass
        raise
    return _DmEnvAndWorldName(env, world_name)
Exemple #5
0
 def tearDown(self):
     """Runs the base tearDown, destroys the world, then closes the server link.

     The server connection is closed even if the base tearDown or the
     DestroyWorld request raises.
     """
     try:
         super().tearDown()
         self.connection.send(
             dm_env_rpc_pb2.DestroyWorldRequest(world_name=self.world_name))
     finally:
         # Previously a failure in the base tearDown skipped this close and
         # leaked the server connection.
         self._server_connection.close()
Exemple #6
0
 def tearDown(self):
   """Leaves and destroys the world, then shuts down the server and channel.

   Every shutdown step runs even if an earlier one raises. A failed Leave or
   Destroy request no longer leaks the gRPC server, connection and channel,
   and no longer skips the base-class tearDown.
   """
   import contextlib

   with contextlib.ExitStack() as cleanup:
     # ExitStack invokes callbacks in reverse registration order. These are
     # registered last-to-first so they run as: stop the server, close the
     # connection, close the channel, then run the base-class tearDown.
     cleanup.callback(super(CatchTestBase, self).tearDown)
     cleanup.callback(self._channel.close)
     cleanup.callback(self._connection.close)
     cleanup.callback(self._server.stop, None)
     self._connection.send(dm_env_rpc_pb2.LeaveWorldRequest())
     self._connection.send(
         dm_env_rpc_pb2.DestroyWorldRequest(world_name=self._world_name))
Exemple #7
0
 def tearDown(self):
   """Leaves and destroys the world when `self._world_name` is set, then cleans up.

   The server connection is always closed and the base-class tearDown always
   runs, even if leaving or destroying the world raises.
   """
   try:
     if self._world_name:
       self._connection.send(dm_env_rpc_pb2.LeaveWorldRequest())
       self._connection.send(
           dm_env_rpc_pb2.DestroyWorldRequest(world_name=self._world_name))
   finally:
     try:
       self._server_connection.close()
     finally:
       # Previously skipped whenever the world cleanup above raised.
       super().tearDown()
def main(_):
    """Runs an interactive Catch session against a locally started server.

    Starts a server on an unused port, creates and joins a world, then loops:
    left/right arrow keys become paddle actions, the environment is stepped,
    and the board and reward are rendered. The loop ends when the window is
    closed or Escape is pressed. On a normal exit the world is left and
    destroyed. The server is stopped on every exit path.
    """
    pygame.init()

    port = portpicker.pick_unused_port()
    server = _start_server(port)
    try:
        with grpc.secure_channel('localhost:{}'.format(port),
                                 grpc.local_channel_credentials()) as channel:
            grpc.channel_ready_future(channel).result()
            with dm_env_rpc_connection.Connection(channel) as connection:
                response = connection.send(dm_env_rpc_pb2.CreateWorldRequest())
                world_name = response.world_name
                response = connection.send(
                    dm_env_rpc_pb2.JoinWorldRequest(world_name=world_name))
                specs = response.specs

                with dm_env_adaptor.DmEnvAdaptor(connection, specs) as dm_env:
                    window_surface = pygame.display.set_mode((800, 600), 0, 32)
                    pygame.display.set_caption('Catch Human Agent')

                    keep_running = True
                    while keep_running:
                        requested_action = _ACTION_NOTHING

                        for event in pygame.event.get():
                            if event.type == pygame.QUIT:
                                keep_running = False
                                break
                            elif event.type == pygame.KEYDOWN:
                                if event.key == pygame.K_LEFT:
                                    requested_action = _ACTION_LEFT
                                elif event.key == pygame.K_RIGHT:
                                    requested_action = _ACTION_RIGHT
                                elif event.key == pygame.K_ESCAPE:
                                    keep_running = False
                                    break

                        actions = {_ACTION_PADDLE: requested_action}
                        step_result = dm_env.step(actions)

                        board = step_result.observation[_OBSERVATION_BOARD]
                        reward = step_result.observation[_OBSERVATION_REWARD]

                        _render_window(board, window_surface, reward)

                        pygame.display.update()

                        pygame.time.wait(_FRAME_DELAY_MS)

                connection.send(dm_env_rpc_pb2.LeaveWorldRequest())
                connection.send(
                    dm_env_rpc_pb2.DestroyWorldRequest(world_name=world_name))
    finally:
        # Previously any exception during the session (connect, create/join,
        # step, render) skipped this and left the server running.
        server.stop(None)
    def test_create_join_world_with_invalid_extension(self):
        """An extension clashing with an adaptor attribute must clean up.

        `create_and_join_world` is expected to raise ValueError after leaving
        and destroying the world it created.
        """
        connection = mock.MagicMock()
        connection.send = mock.MagicMock(side_effect=[
            dm_env_rpc_pb2.CreateWorldResponse(world_name='foo'),
            dm_env_rpc_pb2.JoinWorldResponse(specs=_SAMPLE_SPEC),
            dm_env_rpc_pb2.LeaveWorldResponse(),
            # Reply to the DestroyWorldRequest. This must be a response
            # message; it was mistakenly a DestroyWorldRequest.
            dm_env_rpc_pb2.DestroyWorldResponse()
        ])

        with self.assertRaisesRegex(ValueError,
                                    'DmEnvAdaptor already has attribute'):
            _ = dm_env_adaptor.create_and_join_world(
                connection,
                create_world_settings={},
                join_world_settings={},
                extensions={'step': object()})

        connection.send.assert_has_calls([
            mock.call(dm_env_rpc_pb2.CreateWorldRequest()),
            mock.call(dm_env_rpc_pb2.JoinWorldRequest(world_name='foo')),
            mock.call(dm_env_rpc_pb2.LeaveWorldRequest()),
            mock.call(dm_env_rpc_pb2.DestroyWorldRequest(world_name='foo'))
        ])
Exemple #10
0
def main(_):
    """Runs an interactive Catch session against a locally started server.

    Creates and joins a world, then loops: left/right arrow keys become paddle
    actions, the environment is stepped, and the board and reward are
    rendered. The loop ends when the window is closed or Escape is pressed.
    On a normal exit the world is destroyed. The server is stopped on every
    exit path.
    """
    pygame.init()

    server, port = _start_server()
    try:
        with dm_env_rpc_connection.create_secure_channel_and_connect(
                f'localhost:{port}') as connection:
            env, world_name = dm_env_adaptor.create_and_join_world(
                connection, create_world_settings={}, join_world_settings={})
            with env:
                window_surface = pygame.display.set_mode((800, 600), 0, 32)
                pygame.display.set_caption('Catch Human Agent')

                keep_running = True
                while keep_running:
                    requested_action = _ACTION_NOTHING

                    for event in pygame.event.get():
                        if event.type == pygame.QUIT:
                            keep_running = False
                            break
                        elif event.type == pygame.KEYDOWN:
                            if event.key == pygame.K_LEFT:
                                requested_action = _ACTION_LEFT
                            elif event.key == pygame.K_RIGHT:
                                requested_action = _ACTION_RIGHT
                            elif event.key == pygame.K_ESCAPE:
                                keep_running = False
                                break

                    actions = {_ACTION_PADDLE: requested_action}
                    step_result = env.step(actions)

                    board = step_result.observation[_OBSERVATION_BOARD]
                    reward = step_result.observation[_OBSERVATION_REWARD]

                    _render_window(board, window_surface, reward)

                    pygame.display.update()

                    pygame.time.wait(_FRAME_DELAY_MS)

            connection.send(
                dm_env_rpc_pb2.DestroyWorldRequest(world_name=world_name))
    finally:
        # Previously any exception during the session (connect, create/join,
        # step, render) skipped this and left the server running.
        server.stop(None)
    def test_created_but_failed_to_join_world(self):
        """A failed join must destroy the newly created world and re-raise."""
        world = 'Damogran_01'
        scripted_replies = (
            dm_env_rpc_pb2.CreateWorldResponse(world_name=world),
            error.DmEnvRpcError(status_pb2.Status(message='Failed to Join.')),
            dm_env_rpc_pb2.DestroyWorldResponse(),
        )
        connection = mock.MagicMock()
        connection.send = mock.MagicMock(side_effect=scripted_replies)

        with self.assertRaisesRegex(error.DmEnvRpcError, 'Failed to Join'):
            _ = dm_env_adaptor.create_and_join_world(
                connection, create_world_settings={}, join_world_settings={})

        expected_calls = [
            mock.call(dm_env_rpc_pb2.CreateWorldRequest()),
            mock.call(dm_env_rpc_pb2.JoinWorldRequest(world_name=world)),
            mock.call(dm_env_rpc_pb2.DestroyWorldRequest(world_name=world)),
        ]
        connection.send.assert_has_calls(expected_calls)
    def test_created_and_joined_but_adaptor_failed(self):
        """An adaptor failure after joining must leave and destroy the world."""
        world = 'Damogran_01'
        scripted_replies = (
            dm_env_rpc_pb2.CreateWorldResponse(world_name=world),
            dm_env_rpc_pb2.JoinWorldResponse(specs=_SAMPLE_SPEC),
            dm_env_rpc_pb2.LeaveWorldResponse(),
            dm_env_rpc_pb2.DestroyWorldResponse(),
        )
        connection = mock.MagicMock()
        connection.send = mock.MagicMock(side_effect=scripted_replies)

        with self.assertRaisesRegex(ValueError, 'Unsupported observations'):
            _ = dm_env_adaptor.create_and_join_world(
                connection,
                create_world_settings={},
                join_world_settings={},
                requested_observations=['invalid_observation'])

        expected_calls = [
            mock.call(dm_env_rpc_pb2.CreateWorldRequest()),
            mock.call(dm_env_rpc_pb2.JoinWorldRequest(world_name=world)),
            mock.call(dm_env_rpc_pb2.LeaveWorldRequest()),
            mock.call(dm_env_rpc_pb2.DestroyWorldRequest(world_name=world)),
        ]
        connection.send.assert_has_calls(expected_calls)
Exemple #13
0
 def test_cannot_destroy_world_with_wrong_name(self):
     """Destroying a world under the name 'wrong_name' must be rejected."""
     bad_request = dm_env_rpc_pb2.DestroyWorldRequest(world_name='wrong_name')
     with self.assertRaises(error.DmEnvRpcError):
         self._connection.send(bad_request)
from google.protobuf import any_pb2
from google.protobuf import struct_pb2
from google.rpc import status_pb2
from dm_env_rpc.v1 import connection as dm_env_rpc_connection
from dm_env_rpc.v1 import dm_env_rpc_pb2
from dm_env_rpc.v1 import error
from dm_env_rpc.v1 import tensor_utils

# A CreateWorld request carrying one setting ('foo' -> a tensor packed from
# 'bar'), plus an empty CreateWorld response.
_CREATE_REQUEST = dm_env_rpc_pb2.CreateWorldRequest(
    settings={'foo': tensor_utils.pack_tensor('bar')})
_CREATE_RESPONSE = dm_env_rpc_pb2.CreateWorldResponse()

# A CreateWorld request with no settings, and an EnvironmentResponse whose
# payload is an error status. Presumably paired so the bad request is answered
# with the error; TODO confirm against _REQUEST_RESPONSE_PAIRS.
_BAD_CREATE_REQUEST = dm_env_rpc_pb2.CreateWorldRequest()
_TEST_ERROR = dm_env_rpc_pb2.EnvironmentResponse(error=status_pb2.Status(
    message='A test error.'))

# A DestroyWorld request alongside an EnvironmentResponse holding a
# *leave_world* payload, i.e. a response of a different type than the
# request. Presumably used to check that mismatched responses are rejected;
# verify against the tests that use these.
_INCORRECT_RESPONSE_TEST_MSG = dm_env_rpc_pb2.DestroyWorldRequest(
    world_name='foo')
_INCORRECT_RESPONSE = dm_env_rpc_pb2.EnvironmentResponse(
    leave_world=dm_env_rpc_pb2.LeaveWorldResponse())

# Arbitrary struct_pb2.Value payloads, by their names an extension
# request/response pair.
_EXTENSION_REQUEST = struct_pb2.Value(string_value='extension request')
_EXTENSION_RESPONSE = struct_pb2.Value(number_value=555)


def _wrap_in_any(proto):
    """Returns a new `google.protobuf.Any` message with `proto` packed into it."""
    wrapped = any_pb2.Any()
    wrapped.Pack(proto)
    return wrapped


_REQUEST_RESPONSE_PAIRS = {
    dm_env_rpc_pb2.EnvironmentRequest(create_world=_CREATE_REQUEST).SerializeToString(
Exemple #15
0
 def destroy_world(self, world_name):
     """Sends a DestroyWorldRequest for `world_name`; does nothing for None."""
     if world_name is None:
         return
     request = dm_env_rpc_pb2.DestroyWorldRequest(world_name=world_name)
     self.connection.send(request)