コード例 #1
0
 def test_requested_observation_spec(self):
     """The observation spec contains exactly the requested observation names."""
     wanted = ['foo']
     adaptor = dm_env_adaptor.DmEnvAdaptor(
         self._connection, _SAMPLE_SPEC, wanted)
     # Collect whatever names iterating the observation spec yields.
     spec_names = list(adaptor.observation_spec())
     self.assertEqual(wanted, spec_names)
コード例 #2
0
    def setUp(self):
        """Create the server connection fixture and the DmEnvAdaptor built on it."""
        joined = JoinedServerConnection()
        self._server_connection = joined
        self._connection = joined.connection
        self.world_name = joined.world_name

        self._dm_env = dm_env_adaptor.DmEnvAdaptor(self._connection,
                                                   joined.specs)
        # NOTE(review): the base-class setUp runs last; presumably it relies on
        # the attributes assigned above -- confirm before reordering.
        super().setUp()
コード例 #3
0
 def setUp(self):
     """Create a mocked connection and an adaptor with no requested observations."""
     super(EnvironmentAutomaticallyRequestsReservedKeywords, self).setUp()
     connection = mock.MagicMock()
     self._connection = connection
     self._env = dm_env_adaptor.DmEnvAdaptor(
         connection, _RESERVED_SPEC, requested_observations=[])
     # Every send() on the mocked connection returns the canned step response.
     connection.send = mock.MagicMock(return_value=_RESERVED_STEP_RESPONSE)
コード例 #4
0
def main(_):
    """Run the Catch environment under keyboard control in a pygame window.

    Starts a server on an unused local port, creates and joins a world over
    a local secure gRPC channel, then steps the environment once per frame
    with the action selected by the left/right arrow keys. Closing the
    window or pressing Escape ends the loop, after which the world is left
    and destroyed. The server is stopped on every exit path, including when
    the session raises.

    Args:
        _: Unused positional argument (presumably argv from the application
            runner -- confirm against the caller).
    """
    pygame.init()

    port = portpicker.pick_unused_port()
    server = _start_server(port)

    # Guard the whole session so a failure anywhere below cannot leave the
    # started server running.
    try:
        with grpc.secure_channel('localhost:{}'.format(port),
                                 grpc.local_channel_credentials()) as channel:
            # Block until the channel is connected before sending requests.
            grpc.channel_ready_future(channel).result()
            with dm_env_rpc_connection.Connection(channel) as connection:
                response = connection.send(dm_env_rpc_pb2.CreateWorldRequest())
                world_name = response.world_name
                response = connection.send(
                    dm_env_rpc_pb2.JoinWorldRequest(world_name=world_name))
                specs = response.specs

                with dm_env_adaptor.DmEnvAdaptor(connection, specs) as dm_env:
                    window_surface = pygame.display.set_mode(
                        (800, 600), 0, 32)
                    pygame.display.set_caption('Catch Human Agent')

                    keep_running = True
                    while keep_running:
                        # Default to doing nothing unless a key says otherwise.
                        requested_action = _ACTION_NOTHING

                        for event in pygame.event.get():
                            if event.type == pygame.QUIT:
                                keep_running = False
                                break
                            elif event.type == pygame.KEYDOWN:
                                if event.key == pygame.K_LEFT:
                                    requested_action = _ACTION_LEFT
                                elif event.key == pygame.K_RIGHT:
                                    requested_action = _ACTION_RIGHT
                                elif event.key == pygame.K_ESCAPE:
                                    keep_running = False
                                    break

                        # The breaks above only leave the event loop, so one
                        # final step/render still happens before the while
                        # condition is re-checked.
                        actions = {_ACTION_PADDLE: requested_action}
                        step_result = dm_env.step(actions)

                        board = step_result.observation[_OBSERVATION_BOARD]
                        reward = step_result.observation[_OBSERVATION_REWARD]

                        _render_window(board, window_surface, reward)

                        pygame.display.update()

                        pygame.time.wait(_FRAME_DELAY_MS)

                connection.send(dm_env_rpc_pb2.LeaveWorldRequest())
                connection.send(
                    dm_env_rpc_pb2.DestroyWorldRequest(world_name=world_name))
    finally:
        server.stop(None)
コード例 #5
0
    def test_extensions(self):
        """An extension passed under 'extension' is reachable as env.extension."""
        class _ExampleExtension:
            def foo(self):
                return 'bar'

        extension = _ExampleExtension()
        adaptor = dm_env_adaptor.DmEnvAdaptor(
            connection=mock.MagicMock(),
            specs=_SAMPLE_SPEC,
            extensions={'extension': extension})

        self.assertEqual('bar', adaptor.extension.foo())
コード例 #6
0
 def test_explicitly_requesting_reward_and_discount(self):
     """Explicitly requested reward/discount keys make up the observation spec."""
     reward_key = dm_env_adaptor.DEFAULT_REWARD_KEY
     discount_key = dm_env_adaptor.DEFAULT_DISCOUNT_KEY
     env = dm_env_adaptor.DmEnvAdaptor(
         self._connection,
         _RESERVED_SPEC,
         requested_observations=[reward_key, discount_key])
     # The expected spec maps each requested key to the adaptor's own
     # reward/discount spec.
     wanted_spec = {
         reward_key: env.reward_spec(),
         discount_key: env.discount_spec(),
     }
     self.assertEqual(env.observation_spec(), wanted_spec)
コード例 #7
0
    def test_no_nested_specs(self):
        """With nested_tensors=False the action/observation specs use flat names."""
        env = dm_env_adaptor.DmEnvAdaptor(connection=mock.MagicMock(),
                                          specs=_SAMPLE_NESTED_SPECS,
                                          nested_tensors=False)

        def _flat_scalar_specs():
            # Build a fresh mapping of flat name -> scalar Array spec per call.
            return {
                name: specs.Array(shape=(), dtype=dtype, name=name)
                for name, dtype in (('foo.bar', np.int32), ('baz', np.str_))
            }

        expected_actions = _flat_scalar_specs()
        expected_observations = _flat_scalar_specs()

        # NOTE(review): assertSameElements iterates its arguments; for mappings
        # that presumably compares keys only, not the spec values -- confirm
        # that is the intended strength of this check.
        self.assertSameElements(expected_actions, env.action_spec())
        self.assertSameElements(expected_observations, env.observation_spec())
コード例 #8
0
    def test_no_nested_actions_step(self):
        """Stepping with a flat 'foo.bar' action sends the expected StepRequest."""
        running_response = text_format.Parse(
            """state: RUNNING""", dm_env_rpc_pb2.StepResponse())
        connection = mock.MagicMock()
        connection.send = mock.MagicMock(return_value=running_response)

        env = dm_env_adaptor.DmEnvAdaptor(
            connection,
            specs=_SAMPLE_NESTED_SPECS,
            requested_observations=[],
            nested_tensors=False)
        timestep = env.step({'foo.bar': 123})

        self.assertEqual(dm_env.StepType.FIRST, timestep.step_type)

        # Exactly one request must have gone out, carrying the packed action.
        expected_request = text_format.Parse(
            """actions: { key: 1, value: { int32s: { array: 123 } } }""",
            dm_env_rpc_pb2.StepRequest())
        connection.send.assert_called_once_with(expected_request)
コード例 #9
0
  def test_requested_observations(self):
    """Stepping a filtered adaptor requests only the chosen observation."""
    wanted = ['foo']
    filtered_env = dm_env_adaptor.DmEnvAdaptor(
        self._connection, _SAMPLE_SPEC, wanted)

    # Expected wire request: observation id 1 plus both packed actions.
    packed_actions = {
        1: tensor_utils.pack_tensor(4, dtype=dm_env_rpc_pb2.UINT8),
        2: tensor_utils.pack_tensor('hello'),
    }
    expected_request = dm_env_rpc_pb2.StepRequest(
        requested_observations=[1], actions=packed_actions)

    self._connection.send = mock.MagicMock(return_value=_SAMPLE_STEP_RESPONSE)
    filtered_env.step({'foo': 4, 'bar': 'hello'})

    self._connection.send.assert_called_once_with(expected_request)
コード例 #10
0
    def test_nested_observations_step(self):
        """Stepping with requested 'foo.bar' is checked against a nested result."""
        step_response = text_format.Parse(
            """state: RUNNING
        observations: { key: 1, value: { int32s: { array: 42 } } }""",
            dm_env_rpc_pb2.StepResponse())
        connection = mock.MagicMock()
        connection.send = mock.MagicMock(return_value=step_response)

        expected = {'foo': {'bar': 42}}

        env = dm_env_adaptor.DmEnvAdaptor(
            connection,
            specs=_SAMPLE_NESTED_SPECS,
            requested_observations=['foo.bar'])
        timestep = env.step({})
        self.assertEqual(dm_env.StepType.FIRST, timestep.step_type)
        # NOTE(review): assertSameElements iterates its arguments; for mappings
        # that presumably compares top-level keys only -- confirm.
        self.assertSameElements(expected, timestep.observation)

        expected_request = dm_env_rpc_pb2.StepRequest(
            requested_observations=[1])
        connection.send.assert_called_once_with(expected_request)
コード例 #11
0
 def setUp(self):
   """Build a DmEnvAdaptor over a mocked connection using the sample spec."""
   super(DmEnvAdaptorTests, self).setUp()
   mocked_connection = mock.MagicMock()
   self._connection = mocked_connection
   self._env = dm_env_adaptor.DmEnvAdaptor(mocked_connection, _SAMPLE_SPEC)
コード例 #12
0
 def setUp(self):
   """Build a DmEnvAdaptor over a mocked connection using the reserved spec."""
   super(ReservedKeywordTests, self).setUp()
   self._connection = connection = mock.MagicMock()
   self._env = dm_env_adaptor.DmEnvAdaptor(connection, _RESERVED_SPEC)
コード例 #13
0
 def test_invalid_requested_observations(self):
   """Constructing the adaptor with the 'invalid' observation raises ValueError."""
   unknown_names = ['invalid']
   # Callable form: the constructor is invoked by the assertion helper itself.
   self.assertRaisesRegex(
       ValueError,
       'Unsupported observations requested',
       dm_env_adaptor.DmEnvAdaptor,
       self._connection,
       _SAMPLE_SPEC,
       unknown_names)
コード例 #14
0
 def test_invalid_extension_attr(self):
     """An extension keyed '_connection' must be rejected with ValueError."""
     mocked_connection = mock.MagicMock()
     clashing_extensions = {'_connection': object()}
     with self.assertRaisesRegex(ValueError,
                                 'DmEnvAdaptor already has attribute'):
         dm_env_adaptor.DmEnvAdaptor(connection=mocked_connection,
                                     specs=_SAMPLE_SPEC,
                                     extensions=clashing_extensions)
コード例 #15
0
ファイル: catch_test.py プロジェクト: hengyuan-hu/dm_env_rpc
 def setUp(self):
   """Send a JoinWorldRequest and expose a DmEnvAdaptor as the object under test."""
   super(CatchDmEnvTest, self).setUp()
   join_request = dm_env_rpc_pb2.JoinWorldRequest(world_name=self._world_name)
   join_response = self._connection.send(join_request)
   # The specs returned by the join response configure the adaptor.
   self._dm_env = dm_env_adaptor.DmEnvAdaptor(self._connection,
                                              join_response.specs)
   self.object_under_test = self._dm_env