예제 #1
0
 def close(self):
     """Leaves and destroys the world, then defers to the parent ``close``.

     ``super().close()`` runs whether or not the requests succeed. If the
     leave request raises, no destroy request is sent.
     """
     try:
         leave_request = dm_env_rpc_pb2.LeaveWorldRequest()
         self.connection.send(leave_request)
         destroy_request = dm_env_rpc_pb2.DestroyWorldRequest(
             world_name=self.world_name)
         self.connection.send(destroy_request)
     finally:
         super().close()
예제 #2
0
 def tearDown(self):
   """Leaves and destroys the test world, then shuts everything down.

   The server is stopped and the connection and channel are closed even if
   the leave or destroy request raises, so they are never left open.
   """
   try:
     self._connection.send(dm_env_rpc_pb2.LeaveWorldRequest())
     self._connection.send(
         dm_env_rpc_pb2.DestroyWorldRequest(world_name=self._world_name))
   finally:
     # Release resources unconditionally; a failed RPC must not leak them.
     self._server.stop(None)
     self._connection.close()
     self._channel.close()
     super(CatchTestBase, self).tearDown()
예제 #3
0
 def tearDown(self):
     """Runs the parent teardown, then leaves and destroys the test world.

     The server connection is closed even if either request raises.
     """
     super().tearDown()
     try:
         leave = dm_env_rpc_pb2.LeaveWorldRequest()
         self.connection.send(leave)
         destroy = dm_env_rpc_pb2.DestroyWorldRequest(
             world_name=self.world_name)
         self.connection.send(destroy)
     finally:
         self._server_connection.close()
예제 #4
0
def main(_):
    """Runs an interactive, keyboard-controlled Catch session.

    Starts a local dm_env_rpc server on an unused port, connects to it over a
    local gRPC channel, creates and joins a world, and then loops: read pygame
    events, step the environment with the chosen paddle action, and render the
    board. The Left/Right arrow keys pick the action. Escape or closing the
    window ends the session.

    Once a world has been created, it is left and destroyed even if an
    exception escapes the game loop. The server is always stopped and pygame
    always shut down.

    Args:
      _: Unused.
    """
    pygame.init()

    port = portpicker.pick_unused_port()
    server = _start_server(port)

    try:
        with grpc.secure_channel('localhost:{}'.format(port),
                                 grpc.local_channel_credentials()) as channel:
            grpc.channel_ready_future(channel).result()
            with dm_env_rpc_connection.Connection(channel) as connection:
                response = connection.send(dm_env_rpc_pb2.CreateWorldRequest())
                world_name = response.world_name
                try:
                    response = connection.send(
                        dm_env_rpc_pb2.JoinWorldRequest(world_name=world_name))
                    specs = response.specs

                    with dm_env_adaptor.DmEnvAdaptor(connection,
                                                     specs) as dm_env:
                        window_surface = pygame.display.set_mode(
                            (800, 600), 0, 32)
                        pygame.display.set_caption('Catch Human Agent')

                        while True:
                            requested_action = _ACTION_NOTHING
                            quit_requested = False

                            for event in pygame.event.get():
                                if event.type == pygame.QUIT:
                                    quit_requested = True
                                    break
                                elif event.type == pygame.KEYDOWN:
                                    if event.key == pygame.K_LEFT:
                                        requested_action = _ACTION_LEFT
                                    elif event.key == pygame.K_RIGHT:
                                        requested_action = _ACTION_RIGHT
                                    elif event.key == pygame.K_ESCAPE:
                                        quit_requested = True
                                        break

                            if quit_requested:
                                # The user asked to quit, so don't advance the
                                # environment or draw another frame.
                                break

                            actions = {_ACTION_PADDLE: requested_action}
                            step_result = dm_env.step(actions)

                            board = step_result.observation[_OBSERVATION_BOARD]
                            reward = step_result.observation[
                                _OBSERVATION_REWARD]

                            _render_window(board, window_surface, reward)
                            pygame.display.update()
                            pygame.time.wait(_FRAME_DELAY_MS)
                finally:
                    # Release the world on the server even if joining or the
                    # game loop raised.
                    connection.send(dm_env_rpc_pb2.LeaveWorldRequest())
                    connection.send(
                        dm_env_rpc_pb2.DestroyWorldRequest(
                            world_name=world_name))
    finally:
        server.stop(None)
        pygame.quit()
예제 #5
0
    def test_created_but_failed_to_join_world(self):
        """Checks that a failed JoinWorld still sends LeaveWorld and DestroyWorld."""
        world = 'Damogran_01'
        replies = (
            dm_env_rpc_pb2.CreateWorldResponse(world_name=world),
            error.DmEnvRpcError(status_pb2.Status(message='Failed to Join.')),
            dm_env_rpc_pb2.LeaveWorldResponse(),
            dm_env_rpc_pb2.DestroyWorldResponse(),
        )
        connection = mock.MagicMock()
        connection.send = mock.MagicMock(side_effect=replies)

        with self.assertRaisesRegex(error.DmEnvRpcError, 'Failed to Join'):
            _ = dm_env_adaptor.create_and_join_world(
                connection, create_world_settings={}, join_world_settings={})

        expected_calls = [
            mock.call(dm_env_rpc_pb2.CreateWorldRequest()),
            mock.call(dm_env_rpc_pb2.JoinWorldRequest(world_name=world)),
            mock.call(dm_env_rpc_pb2.LeaveWorldRequest()),
            mock.call(dm_env_rpc_pb2.DestroyWorldRequest(world_name=world)),
        ]
        connection.send.assert_has_calls(expected_calls)
예제 #6
0
def join_world(
    connection: dm_env_rpc_connection.Connection,
    world_name: str,
    join_world_settings: Mapping[str, Any],
    requested_observations: Optional[Iterable[str]] = None,
    extensions: Optional[Mapping[str, Any]] = immutabledict.immutabledict()
) -> DmEnvAdaptor:
    """Helper function to join a world with the provided settings.

    Args:
      connection: An instance of Connection already connected to a dm_env_rpc
        server.
      world_name: Name of the world to join.
      join_world_settings: Settings used to join the world. Values must be
        packable into a Tensor message or already packed.
      requested_observations: Optional set of requested observations.
      extensions: Optional mapping of extension instances to DmEnvAdaptor
        attributes.

    Returns:
      Instance of DmEnvAdaptor.

    Raises:
      Exception: Whatever sending the JoinWorldRequest or constructing the
        DmEnvAdaptor raises. If construction fails after the join succeeded,
        a LeaveWorldRequest is sent before the error is re-raised.
    """
    # Pack plain values into Tensor protos. Values that are already Tensors
    # are passed through unchanged.
    packed_settings = {
        key: (value if isinstance(value, dm_env_rpc_pb2.Tensor) else
              tensor_utils.pack_tensor(value))
        for key, value in join_world_settings.items()
    }
    specs = connection.send(
        dm_env_rpc_pb2.JoinWorldRequest(world_name=world_name,
                                        settings=packed_settings)).specs

    try:
        return DmEnvAdaptor(connection,
                            specs,
                            requested_observations,
                            extensions=extensions)
    except Exception:
        # We are joined but cannot return an adaptor. Leave the world
        # whatever the failure was, then propagate the original error.
        connection.send(dm_env_rpc_pb2.LeaveWorldRequest())
        raise
예제 #7
0
    def test_create_join_world_with_invalid_extension(self):
        """Checks cleanup when an extension name clashes with an adaptor attribute.

        Passing an extension named 'step' must raise ValueError. The world that
        was created and joined must still be left and destroyed.
        """
        connection = mock.MagicMock()
        connection.send = mock.MagicMock(side_effect=[
            dm_env_rpc_pb2.CreateWorldResponse(world_name='foo'),
            dm_env_rpc_pb2.JoinWorldResponse(specs=_SAMPLE_SPEC),
            dm_env_rpc_pb2.LeaveWorldResponse(),
            # The mocked server answers a DestroyWorldRequest with a response
            # message, matching the other create/join failure tests.
            dm_env_rpc_pb2.DestroyWorldResponse()
        ])

        with self.assertRaisesRegex(ValueError,
                                    'DmEnvAdaptor already has attribute'):
            _ = dm_env_adaptor.create_and_join_world(
                connection,
                create_world_settings={},
                join_world_settings={},
                extensions={'step': object()})

        connection.send.assert_has_calls([
            mock.call(dm_env_rpc_pb2.CreateWorldRequest()),
            mock.call(dm_env_rpc_pb2.JoinWorldRequest(world_name='foo')),
            mock.call(dm_env_rpc_pb2.LeaveWorldRequest()),
            mock.call(dm_env_rpc_pb2.DestroyWorldRequest(world_name='foo'))
        ])
예제 #8
0
    def test_created_and_joined_but_adaptor_failed(self):
        """Checks that rejected observations still lead to leave and destroy.

        The adaptor is asked for an unknown observation, which must raise
        ValueError after the world was created and joined.
        """
        world_name = 'Damogran_01'
        bad_observations = ['invalid_observation']

        connection = mock.MagicMock()
        connection.send = mock.MagicMock(side_effect=(
            dm_env_rpc_pb2.CreateWorldResponse(world_name=world_name),
            dm_env_rpc_pb2.JoinWorldResponse(specs=_SAMPLE_SPEC),
            dm_env_rpc_pb2.LeaveWorldResponse(),
            dm_env_rpc_pb2.DestroyWorldResponse(),
        ))

        with self.assertRaisesRegex(ValueError, 'Unsupported observations'):
            _ = dm_env_adaptor.create_and_join_world(
                connection,
                create_world_settings={},
                join_world_settings={},
                requested_observations=bad_observations)

        expected_calls = [
            mock.call(dm_env_rpc_pb2.CreateWorldRequest()),
            mock.call(dm_env_rpc_pb2.JoinWorldRequest(world_name=world_name)),
            mock.call(dm_env_rpc_pb2.LeaveWorldRequest()),
            mock.call(
                dm_env_rpc_pb2.DestroyWorldRequest(world_name=world_name)),
        ]
        connection.send.assert_has_calls(expected_calls)
예제 #9
0
 def leave_world(self):
     """Sends a LeaveWorldRequest over ``self.connection``."""
     request = dm_env_rpc_pb2.LeaveWorldRequest()
     self.connection.send(request)
예제 #10
0
 def close(self):
     """Implements dm_env.Environment.close.

     Sends a LeaveWorldRequest and then drops the connection reference. The
     reference is cleared even if the request raises. Calling ``close()``
     again after that is a no-op instead of failing on a ``None`` connection.
     """
     if self._connection is None:
         # Already closed.
         return
     try:
         # Leaves the world if we were joined.  If not, this will be a no-op anyway.
         self._connection.send(dm_env_rpc_pb2.LeaveWorldRequest())
     finally:
         self._connection = None
예제 #11
0
 def test_close_leaves_world(self):
   """Closing the environment sends exactly one LeaveWorldRequest."""
   leave_response = dm_env_rpc_pb2.LeaveWorldResponse()
   self._connection.send = mock.MagicMock(return_value=leave_response)

   self._env.close()

   expected_request = dm_env_rpc_pb2.LeaveWorldRequest()
   self._connection.send.assert_called_once_with(expected_request)
예제 #12
0
 def leave_world(self):
     """Asks the server to leave the currently joined world, if any."""
     leave_request = dm_env_rpc_pb2.LeaveWorldRequest()
     self.connection.send(leave_request)