def _connect_to_environment(port, settings):
    """Creates and joins a dm_hard_eight world on a running environment.

    Args:
        port: Port handed to `_create_channel_and_connection`.
        settings: Object exposing `level_name`, `seed`, `width` and `height`.

    Returns:
        A `_ConnectionDetails` with the channel, the connection (whose `send`
        is routed through `_wrap_send`) and the specs from the join response.

    Raises:
        ValueError: If `settings.level_name` is not listed in
            `HARD_EIGHT_TASK_LEVEL_NAMES`.
    """
    if settings.level_name not in HARD_EIGHT_TASK_LEVEL_NAMES:
        message = 'Level named "{}" is not supported for dm_hard_eight'.format(
            settings.level_name)
        raise ValueError(message)

    channel, connection = _create_channel_and_connection(port)
    raw_send = connection.send

    def _send_wrapped(request):
        # Hand `_wrap_send` a thunk so it decides when the RPC is issued.
        return _wrap_send(lambda: raw_send(request))

    connection.send = _send_wrapped

    create_request = dm_env_rpc_pb2.CreateWorldRequest(
        settings={
            'seed': tensor_utils.pack_tensor(settings.seed),
            'episodeId': tensor_utils.pack_tensor(0),
            'levelName': tensor_utils.pack_tensor(settings.level_name),
        })
    world_name = connection.send(create_request).world_name

    join_request = dm_env_rpc_pb2.JoinWorldRequest(
        world_name=world_name,
        settings={
            'width': tensor_utils.pack_tensor(settings.width),
            'height': tensor_utils.pack_tensor(settings.height),
        })
    specs = connection.send(join_request).specs

    return _ConnectionDetails(channel=channel,
                              connection=connection,
                              specs=specs)
示例#2
0
def _connect_to_environment(port, settings):
    """Creates and joins a dm_construction world on a running environment.

    Args:
        port: Port handed to `_create_channel_and_connection`.
        settings: Mapping of setting names to values. Must contain
            "levelName"; every other entry is sent with the join request.

    Returns:
        A `(channel, connection, specs)` tuple, where `connection.send` has
        been routed through `_wrap_send` and `specs` comes from the join
        response.
    """
    channel, connection = _create_channel_and_connection(port)
    raw_send = connection.send

    def _send_wrapped(request):
        # Hand `_wrap_send` a thunk so it decides when the RPC is issued.
        return _wrap_send(lambda: raw_send(request))

    connection.send = _send_wrapped

    # Pack every setting up front, including the level name.
    packed = {}
    for name, value in settings.items():
        packed[name] = tensor_utils.pack_tensor(value)

    # The world is created with a fixed seed and episode id of 0; only the
    # level name is taken from the caller's settings.
    create_request = dm_env_rpc_pb2.CreateWorldRequest(
        settings={
            "levelName": packed["levelName"],
            "seed": tensor_utils.pack_tensor(0),
            "episodeId": tensor_utils.pack_tensor(0),
        })
    world_name = connection.send(create_request).world_name

    # Everything except the level name belongs to the join request.
    join_settings = {
        name: tensor
        for name, tensor in packed.items() if name != "levelName"
    }
    join_request = dm_env_rpc_pb2.JoinWorldRequest(world_name=world_name,
                                                   settings=join_settings)
    specs = connection.send(join_request).specs

    return channel, connection, specs
示例#3
0
 def test_cannot_join_when_no_world_exists(self):
     # Joining must raise DmEnvRpcError once the world has been destroyed.
     world_name = self._world_name
     destroy_request = dm_env_rpc_pb2.DestroyWorldRequest(
         world_name=world_name)
     self._connection.send(destroy_request)

     join_request = dm_env_rpc_pb2.JoinWorldRequest(world_name=world_name)
     with self.assertRaises(error.DmEnvRpcError):
         self._connection.send(join_request)

     # Create a world again after the destroy above; presumably so the
     # fixture's teardown still has a world to clean up -- TODO confirm.
     self._connection.send(dm_env_rpc_pb2.CreateWorldRequest())
示例#4
0
def _connect_to_environment(port, settings):
    """Creates and joins an Alchemy world on a running environment.

    Args:
        port: Port handed to `_create_channel_and_connection`.
        settings: Object exposing `level_name`, `seed`, `width` and `height`.

    Returns:
        A `_ConnectionDetails` with the channel, the connection (whose `send`
        is routed through `_wrap_send`) and the specs from the join response.

    Raises:
        ValueError: If `settings.level_name` is not listed in
            `ALCHEMY_LEVEL_NAMES`.
    """
    if settings.level_name not in ALCHEMY_LEVEL_NAMES:
        message = 'Level named "{}" is not a valid dm_alchemy level.'.format(
            settings.level_name)
        raise ValueError(message)

    channel, connection = _create_channel_and_connection(port)
    raw_send = connection.send

    def _send_wrapped(request):
        # Hand `_wrap_send` a thunk so it decides when the RPC is issued.
        return _wrap_send(lambda: raw_send(request))

    connection.send = _send_wrapped

    create_request = dm_env_rpc_pb2.CreateWorldRequest(
        settings={
            'seed': tensor_utils.pack_tensor(settings.seed),
            'episodeId': tensor_utils.pack_tensor(0),
            'levelName': tensor_utils.pack_tensor(settings.level_name),
            'EventSubscriptionRegex': tensor_utils.pack_tensor('DeepMind/.*'),
        })
    world_name = connection.send(create_request).world_name

    join_request = dm_env_rpc_pb2.JoinWorldRequest(
        world_name=world_name,
        settings={
            'width': tensor_utils.pack_tensor(settings.width),
            'height': tensor_utils.pack_tensor(settings.height),
        })
    specs = connection.send(join_request).specs

    return _ConnectionDetails(channel=channel,
                              connection=connection,
                              specs=specs)
示例#5
0
 def test_cannot_destroy_world_when_still_joined(self):
     # Destroying a world right after joining it must raise DmEnvRpcError.
     world_name = self._world_name
     join_request = dm_env_rpc_pb2.JoinWorldRequest(world_name=world_name)
     self._connection.send(join_request)

     destroy_request = dm_env_rpc_pb2.DestroyWorldRequest(
         world_name=world_name)
     with self.assertRaises(error.DmEnvRpcError):
         self._connection.send(destroy_request)
    def test_create_join_world(self):
        # Plain Python settings given to create_and_join_world should reach
        # the connection as tensors inside the create and join requests.
        connection = mock.MagicMock()
        mocked_responses = [
            dm_env_rpc_pb2.CreateWorldResponse(world_name='Damogran_01'),
            dm_env_rpc_pb2.JoinWorldResponse(specs=_SAMPLE_SPEC),
        ]
        connection.send = mock.MagicMock(side_effect=mocked_responses)

        create_settings = {'planet': 'Damogran'}
        join_settings = {'ship_type': 1, 'player': 'zaphod'}
        env, world_name = dm_env_adaptor.create_and_join_world(
            connection,
            create_world_settings=create_settings,
            join_world_settings=join_settings)

        self.assertIsNotNone(env)
        self.assertEqual('Damogran_01', world_name)

        expected_create = text_format.Parse(
            """settings: {
                key: 'planet', value: { strings: { array: 'Damogran' } }
            }""", dm_env_rpc_pb2.CreateWorldRequest())
        expected_join = text_format.Parse(
            """world_name: 'Damogran_01'
                settings: { key: 'ship_type', value: { int64s: { array: 1 } } }
                settings: {
                    key: 'player', value: { strings: { array: 'zaphod' } }
                }""", dm_env_rpc_pb2.JoinWorldRequest())
        connection.send.assert_has_calls([
            mock.call(expected_create),
            mock.call(expected_join),
        ])
def _connect_to_environment(port, settings):
    """Creates and joins a dm_fast_mapping world on a running environment.

    Args:
        port: Port handed to `_create_channel_and_connection`.
        settings: Object exposing `level_name`, `seed`, `width`, `height` and
            `episode_length_seconds`.

    Returns:
        A `_ConnectionDetails` with the channel, the connection (whose `send`
        is routed through `_wrap_send`) and the specs from the join response.

    Raises:
        ValueError: If `settings.level_name` is not listed in
            `FAST_MAPPING_TASK_LEVEL_NAMES`.
    """
    if settings.level_name not in FAST_MAPPING_TASK_LEVEL_NAMES:
        message = (
            'Level named "{}" is not a valid dm_fast_mapping level.'.format(
                settings.level_name))
        raise ValueError(message)

    channel, connection = _create_channel_and_connection(port)
    raw_send = connection.send

    def _send_wrapped(request):
        # Hand `_wrap_send` a thunk so it decides when the RPC is issued.
        return _wrap_send(lambda: raw_send(request))

    connection.send = _send_wrapped

    create_request = dm_env_rpc_pb2.CreateWorldRequest(
        settings={
            'seed': tensor_utils.pack_tensor(settings.seed),
            'episodeId': tensor_utils.pack_tensor(0),
            'levelName': tensor_utils.pack_tensor(settings.level_name),
        })
    world_name = connection.send(create_request).world_name

    # The reachability HUD setting is always sent as False on join.
    join_request = dm_env_rpc_pb2.JoinWorldRequest(
        world_name=world_name,
        settings={
            'width': tensor_utils.pack_tensor(settings.width),
            'height': tensor_utils.pack_tensor(settings.height),
            'EpisodeLengthSeconds': tensor_utils.pack_tensor(
                settings.episode_length_seconds),
            'ShowReachabilityHUD': tensor_utils.pack_tensor(False),
        })
    specs = connection.send(join_request).specs

    return _ConnectionDetails(channel=channel,
                              connection=connection,
                              specs=specs)
示例#8
0
    def __init__(self):
        """Creates a world with default settings and joins it immediately.

        Stores the created world's name in `self.world_name` and the specs
        from the join response in `self.specs`.
        """
        super().__init__()
        create_response = self.connection.send(
            dm_env_rpc_pb2.CreateWorldRequest())
        self.world_name = create_response.world_name

        join_request = dm_env_rpc_pb2.JoinWorldRequest(
            world_name=self.world_name)
        join_response = self.connection.send(join_request)
        self.specs = join_response.specs
def main(_):
    """Runs the Catch example as a keyboard-driven pygame session.

    Starts a server on a free local port via `_start_server`, creates and
    joins a world over a local gRPC channel, then steps the environment once
    per frame with the paddle action selected by the left/right arrow keys.
    The loop ends on a window-close event or the Escape key, after which the
    world is left and destroyed. The server is stopped on every exit path,
    including when the session raises.

    Args:
        _: Unused positional argument.
    """
    pygame.init()

    port = portpicker.pick_unused_port()
    server = _start_server(port)

    # Guarantee the server is stopped even if connecting, stepping or
    # rendering fails part-way through the session.
    try:
        with grpc.secure_channel('localhost:{}'.format(port),
                                 grpc.local_channel_credentials()) as channel:
            # Block until the channel is usable before opening a connection.
            grpc.channel_ready_future(channel).result()
            with dm_env_rpc_connection.Connection(channel) as connection:
                response = connection.send(dm_env_rpc_pb2.CreateWorldRequest())
                world_name = response.world_name
                response = connection.send(
                    dm_env_rpc_pb2.JoinWorldRequest(world_name=world_name))
                specs = response.specs

                with dm_env_adaptor.DmEnvAdaptor(connection, specs) as dm_env:
                    window_surface = pygame.display.set_mode((800, 600), 0, 32)
                    pygame.display.set_caption('Catch Human Agent')

                    keep_running = True
                    while keep_running:
                        # Default to doing nothing unless a key says otherwise.
                        requested_action = _ACTION_NOTHING

                        # `break` only leaves the event loop: the current
                        # frame is still stepped and rendered before the
                        # outer loop observes `keep_running` and exits.
                        for event in pygame.event.get():
                            if event.type == pygame.QUIT:
                                keep_running = False
                                break
                            elif event.type == pygame.KEYDOWN:
                                if event.key == pygame.K_LEFT:
                                    requested_action = _ACTION_LEFT
                                elif event.key == pygame.K_RIGHT:
                                    requested_action = _ACTION_RIGHT
                                elif event.key == pygame.K_ESCAPE:
                                    keep_running = False
                                    break

                        actions = {_ACTION_PADDLE: requested_action}
                        step_result = dm_env.step(actions)

                        board = step_result.observation[_OBSERVATION_BOARD]
                        reward = step_result.observation[_OBSERVATION_REWARD]

                        _render_window(board, window_surface, reward)

                        pygame.display.update()

                        pygame.time.wait(_FRAME_DELAY_MS)

                connection.send(dm_env_rpc_pb2.LeaveWorldRequest())
                connection.send(
                    dm_env_rpc_pb2.DestroyWorldRequest(world_name=world_name))
    finally:
        server.stop(None)
示例#10
0
    def test_reset_seed_setting(self):
        # Resetting with the same seed the world was created with should make
        # the next step response equal to the one taken before the reset.
        create_request = dm_env_rpc_pb2.CreateWorldRequest(
            settings={'seed': tensor_utils.pack_tensor(1234)})
        self._world_name = self._connection.send(create_request).world_name

        join_request = dm_env_rpc_pb2.JoinWorldRequest(
            world_name=self._world_name)
        self._connection.send(join_request)

        step_before_reset = self._connection.send(dm_env_rpc_pb2.StepRequest())

        reset_request = dm_env_rpc_pb2.ResetRequest(
            settings={'seed': tensor_utils.pack_tensor(1234)})
        self._connection.send(reset_request)

        step_after_reset = self._connection.send(dm_env_rpc_pb2.StepRequest())
        self.assertEqual(step_before_reset, step_after_reset)
示例#11
0
    def test_join_world(self):
        # join_world should send exactly one JoinWorldRequest carrying the
        # world name and the packed settings, and return an environment.
        join_response = dm_env_rpc_pb2.JoinWorldResponse(specs=_SAMPLE_SPEC)
        connection = mock.MagicMock()
        connection.send = mock.MagicMock(return_value=join_response)

        settings = {'player': 'zaphod'}
        env = dm_env_adaptor.join_world(connection, 'Damogran_01', settings)
        self.assertIsNotNone(env)

        expected_request = text_format.Parse(
            """world_name: 'Damogran_01'
                settings: {
                    key: 'player', value: { strings: { array: 'zaphod' } }
                }""", dm_env_rpc_pb2.JoinWorldRequest())
        connection.send.assert_called_once_with(expected_request)
示例#12
0
    def test_created_but_failed_to_join_world(self):
        # When the join call raises, the error should propagate and a destroy
        # request for the freshly created world should follow.
        join_failure = error.DmEnvRpcError(
            status_pb2.Status(message='Failed to Join.'))
        mocked_replies = (
            dm_env_rpc_pb2.CreateWorldResponse(world_name='Damogran_01'),
            join_failure,
            dm_env_rpc_pb2.DestroyWorldResponse(),
        )
        connection = mock.MagicMock()
        connection.send = mock.MagicMock(side_effect=mocked_replies)

        with self.assertRaisesRegex(error.DmEnvRpcError, 'Failed to Join'):
            _ = dm_env_adaptor.create_and_join_world(connection,
                                                     create_world_settings={},
                                                     join_world_settings={})

        expected_calls = [
            mock.call(dm_env_rpc_pb2.CreateWorldRequest()),
            mock.call(
                dm_env_rpc_pb2.JoinWorldRequest(world_name='Damogran_01')),
            mock.call(
                dm_env_rpc_pb2.DestroyWorldRequest(world_name='Damogran_01')),
        ]
        connection.send.assert_has_calls(expected_calls)
示例#13
0
def join_world(
    connection: dm_env_rpc_connection.Connection,
    world_name: str,
    join_world_settings: Mapping[str, Any],
    requested_observations: Optional[Iterable[str]] = None,
    extensions: Optional[Mapping[str, Any]] = immutabledict.immutabledict()
) -> DmEnvAdaptor:
    """Joins an existing world and wraps the connection in a DmEnvAdaptor.

    Args:
        connection: Connection on which the JoinWorldRequest is sent.
        world_name: Name of the world to join.
        join_world_settings: Settings for the join request. Values that are
            already `dm_env_rpc_pb2.Tensor` messages are sent unchanged; all
            other values are packed with `tensor_utils.pack_tensor`.
        requested_observations: Optional observations forwarded to
            DmEnvAdaptor.
        extensions: Optional mapping forwarded to DmEnvAdaptor as its
            `extensions` argument.

    Returns:
        The DmEnvAdaptor built from the specs in the join response.

    Raises:
        ValueError: Re-raised from DmEnvAdaptor construction, after a
            LeaveWorldRequest has been sent on the connection.
    """
    packed_settings = {}
    for name, setting in join_world_settings.items():
        if not isinstance(setting, dm_env_rpc_pb2.Tensor):
            setting = tensor_utils.pack_tensor(setting)
        packed_settings[name] = setting

    join_request = dm_env_rpc_pb2.JoinWorldRequest(world_name=world_name,
                                                   settings=packed_settings)
    specs = connection.send(join_request).specs

    try:
        adaptor = DmEnvAdaptor(connection,
                               specs,
                               requested_observations,
                               extensions=extensions)
    except ValueError:
        # The world was already joined; leave it before surfacing the error.
        connection.send(dm_env_rpc_pb2.LeaveWorldRequest())
        raise
    return adaptor
示例#14
0
    def test_create_join_world_with_packed_settings(self):
        # Settings that are already packed tensors may be mixed with plain
        # values; both should appear as tensors in the outgoing requests.
        connection = mock.MagicMock()
        mocked_responses = [
            dm_env_rpc_pb2.CreateWorldResponse(world_name='Magrathea_02'),
            dm_env_rpc_pb2.JoinWorldResponse(specs=_SAMPLE_SPEC),
        ]
        connection.send = mock.MagicMock(side_effect=mocked_responses)

        create_settings = {'planet': tensor_utils.pack_tensor('Magrathea')}
        join_settings = {
            'ship_type': tensor_utils.pack_tensor(2),
            'player': tensor_utils.pack_tensor('arthur'),
            'unpacked_setting': [1, 2, 3],
        }
        env, world_name = dm_env_adaptor.create_and_join_world(
            connection,
            create_world_settings=create_settings,
            join_world_settings=join_settings)

        self.assertIsNotNone(env)
        self.assertEqual('Magrathea_02', world_name)

        expected_create = text_format.Parse(
            """settings: {
                key: 'planet', value: { strings: { array: 'Magrathea' } }
            }""", dm_env_rpc_pb2.CreateWorldRequest())
        expected_join = text_format.Parse(
            """world_name: 'Magrathea_02'
                settings: { key: 'ship_type', value: { int64s: { array: 2 } } }
                settings: {
                    key: 'player', value: { strings: { array: 'arthur' } }
                }
                settings: {
                    key: 'unpacked_setting', value: {
                      int64s: { array: 1 array: 2 array: 3 }
                      shape: 3
                    }
                }""", dm_env_rpc_pb2.JoinWorldRequest())
        connection.send.assert_has_calls([
            mock.call(expected_create),
            mock.call(expected_join),
        ])
示例#15
0
    def test_create_join_world_with_invalid_extension(self):
        # An extension keyed 'step' should make create_and_join_world raise
        # ValueError, after which the world that was created and joined is
        # expected to be left and destroyed.
        connection = mock.MagicMock()
        connection.send = mock.MagicMock(side_effect=[
            dm_env_rpc_pb2.CreateWorldResponse(world_name='foo'),
            dm_env_rpc_pb2.JoinWorldResponse(specs=_SAMPLE_SPEC),
            dm_env_rpc_pb2.LeaveWorldResponse(),
            # The mocked reply to the destroy call is a response message,
            # like every other reply in this list.
            dm_env_rpc_pb2.DestroyWorldResponse()
        ])

        with self.assertRaisesRegex(ValueError,
                                    'DmEnvAdaptor already has attribute'):
            _ = dm_env_adaptor.create_and_join_world(
                connection,
                create_world_settings={},
                join_world_settings={},
                extensions={'step': object()})

        connection.send.assert_has_calls([
            mock.call(dm_env_rpc_pb2.CreateWorldRequest()),
            mock.call(dm_env_rpc_pb2.JoinWorldRequest(world_name='foo')),
            mock.call(dm_env_rpc_pb2.LeaveWorldRequest()),
            mock.call(dm_env_rpc_pb2.DestroyWorldRequest(world_name='foo'))
        ])
def _make_environment_connection(channel, connection, settings):
  """Creates and joins a dm_memorytask world over an existing connection.

  Args:
    channel: Channel returned unchanged inside the result.
    connection: Connection used for the requests; its `send` is replaced by a
      version routed through `_wrap_send`.
    settings: Object exposing `seed`, `level_name`, `width`, `height` and
      `episode_length_seconds`.

  Returns:
    A `_ConnectionDetails` with the channel, the connection and the specs
    from the join response.
  """
  raw_send = connection.send

  def _send_wrapped(request):
    # Hand `_wrap_send` a thunk so it decides when the RPC is issued.
    return _wrap_send(lambda: raw_send(request))

  connection.send = _send_wrapped

  create_request = dm_env_rpc_pb2.CreateWorldRequest(
      settings={
          'seed': tensor_utils.pack_tensor(settings.seed),
          'episodeId': tensor_utils.pack_tensor(0),
          'levelName': tensor_utils.pack_tensor(settings.level_name),
      })
  world_name = connection.send(create_request).world_name

  join_request = dm_env_rpc_pb2.JoinWorldRequest(
      world_name=world_name,
      settings={
          'width': tensor_utils.pack_tensor(settings.width),
          'height': tensor_utils.pack_tensor(settings.height),
          'EpisodeLengthSeconds': tensor_utils.pack_tensor(
              settings.episode_length_seconds),
      })
  specs = connection.send(join_request).specs

  return _ConnectionDetails(channel=channel, connection=connection, specs=specs)
示例#17
0
    def test_created_and_joined_but_adaptor_failed(self):
        # Requesting an unsupported observation should raise ValueError after
        # the world was created and joined; leave and destroy requests are
        # then expected on the connection.
        mocked_replies = (
            dm_env_rpc_pb2.CreateWorldResponse(world_name='Damogran_01'),
            dm_env_rpc_pb2.JoinWorldResponse(specs=_SAMPLE_SPEC),
            dm_env_rpc_pb2.LeaveWorldResponse(),
            dm_env_rpc_pb2.DestroyWorldResponse(),
        )
        connection = mock.MagicMock()
        connection.send = mock.MagicMock(side_effect=mocked_replies)

        with self.assertRaisesRegex(ValueError, 'Unsupported observations'):
            _ = dm_env_adaptor.create_and_join_world(
                connection,
                create_world_settings={},
                join_world_settings={},
                requested_observations=['invalid_observation'])

        expected_calls = [
            mock.call(dm_env_rpc_pb2.CreateWorldRequest()),
            mock.call(
                dm_env_rpc_pb2.JoinWorldRequest(world_name='Damogran_01')),
            mock.call(dm_env_rpc_pb2.LeaveWorldRequest()),
            mock.call(
                dm_env_rpc_pb2.DestroyWorldRequest(world_name='Damogran_01')),
        ]
        connection.send.assert_has_calls(expected_calls)
示例#18
0
 def setUp(self):
   # Join the fixture's world and wrap the connection in a DmEnvAdaptor,
   # exposed both as `_dm_env` and as `object_under_test`.
   super(CatchDmEnvTest, self).setUp()
   join_request = dm_env_rpc_pb2.JoinWorldRequest(
       world_name=self._world_name)
   join_response = self._connection.send(join_request)
   adaptor = dm_env_adaptor.DmEnvAdaptor(self._connection,
                                         join_response.specs)
   self._dm_env = adaptor
   self.object_under_test = adaptor
示例#19
0
 def test_cannot_join_world_with_wrong_name(self):
     # A join request for the world name 'wrong_name' must raise
     # DmEnvRpcError.
     bad_join_request = dm_env_rpc_pb2.JoinWorldRequest(
         world_name='wrong_name')
     with self.assertRaises(error.DmEnvRpcError):
         self._connection.send(bad_join_request)
示例#20
0
 def test_can_reset_world_when_joined(self):
     # After joining the world, a ResetWorldRequest should go through
     # without raising.
     join_request = dm_env_rpc_pb2.JoinWorldRequest(
         world_name=self._world_name)
     self._connection.send(join_request)

     reset_request = dm_env_rpc_pb2.ResetWorldRequest()
     self._connection.send(reset_request)
示例#21
0
 def join_world(self):
     """Sends a join request for `self.world_name` and returns the specs."""
     join_request = dm_env_rpc_pb2.JoinWorldRequest(
         world_name=self.world_name)
     join_response = self.connection.send(join_request)
     return join_response.specs
示例#22
0
 def join_world(self, **kwargs):
     """Sends a JoinWorldRequest built from `kwargs` and returns the specs.

     Args:
         **kwargs: Passed unchanged to the JoinWorldRequest constructor.
     """
     join_request = dm_env_rpc_pb2.JoinWorldRequest(**kwargs)
     join_response = self.connection.send(join_request)
     return join_response.specs
示例#23
0
 def join_world(self):
     """Joins `self.world_name` using the required join settings.

     The join response is not returned or stored.
     """
     join_request = dm_env_rpc_pb2.JoinWorldRequest(
         world_name=self.world_name,
         settings=self.required_join_world_settings)
     self.connection.send(join_request)