Ejemplo n.º 1
0
    def test_create_join_world(self):
        """Checks that create_and_join_world sends CreateWorld, then JoinWorld.

        The mocked connection answers the two requests in order. The test
        verifies the returned environment and world name, then the exact
        request protos that were sent, including the packed settings.
        """
        replies = [
            dm_env_rpc_pb2.CreateWorldResponse(world_name='Damogran_01'),
            dm_env_rpc_pb2.JoinWorldResponse(specs=_SAMPLE_SPEC),
        ]
        mock_connection = mock.MagicMock()
        mock_connection.send = mock.MagicMock(side_effect=replies)

        env, world_name = dm_env_adaptor.create_and_join_world(
            mock_connection,
            create_world_settings={'planet': 'Damogran'},
            join_world_settings={'ship_type': 1, 'player': 'zaphod'})

        self.assertIsNotNone(env)
        self.assertEqual('Damogran_01', world_name)

        # Plain Python settings are expected to arrive packed as tensors.
        expected_create = text_format.Parse(
            """
            settings: {
              key: 'planet', value: { strings: { array: 'Damogran' } }
            }
            """, dm_env_rpc_pb2.CreateWorldRequest())
        expected_join = text_format.Parse(
            """
            world_name: 'Damogran_01'
            settings: { key: 'ship_type', value: { int64s: { array: 1 } } }
            settings: {
              key: 'player', value: { strings: { array: 'zaphod' } }
            }
            """, dm_env_rpc_pb2.JoinWorldRequest())
        mock_connection.send.assert_has_calls(
            [mock.call(expected_create), mock.call(expected_join)])
Ejemplo n.º 2
0
def main(_):
    """Runs an interactive, keyboard-driven session against a local server.

    Starts a local server, connects to it over a secure channel, creates and
    joins a world, then repeatedly:
      * reads pygame events (Left/Right pick a paddle action; window close or
        Escape ends the session),
      * steps the environment with the chosen action,
      * renders the returned board and reward into an 800x600 window.
    On normal exit the world is destroyed. The server is always stopped and
    pygame always shut down, even if the session raises.

    Args:
      _: Unused positional argument (presumably argv from the app launcher
        -- confirm against the caller).
    """
    pygame.init()

    server, port = _start_server()
    try:
        with dm_env_rpc_connection.create_secure_channel_and_connect(
                f'localhost:{port}') as connection:
            env, world_name = dm_env_adaptor.create_and_join_world(
                connection, create_world_settings={}, join_world_settings={})
            with env:
                window_surface = pygame.display.set_mode((800, 600), 0, 32)
                pygame.display.set_caption('Catch Human Agent')

                keep_running = True
                while keep_running:
                    requested_action = _ACTION_NOTHING

                    for event in pygame.event.get():
                        if event.type == pygame.QUIT:
                            keep_running = False
                            break
                        elif event.type == pygame.KEYDOWN:
                            if event.key == pygame.K_LEFT:
                                requested_action = _ACTION_LEFT
                            elif event.key == pygame.K_RIGHT:
                                requested_action = _ACTION_RIGHT
                            elif event.key == pygame.K_ESCAPE:
                                keep_running = False
                                break

                    if not keep_running:
                        # The user asked to stop: don't step the environment
                        # or draw another frame.
                        break

                    actions = {_ACTION_PADDLE: requested_action}
                    step_result = env.step(actions)

                    board = step_result.observation[_OBSERVATION_BOARD]
                    reward = step_result.observation[_OBSERVATION_REWARD]

                    _render_window(board, window_surface, reward)

                    pygame.display.update()

                    pygame.time.wait(_FRAME_DELAY_MS)

            connection.send(
                dm_env_rpc_pb2.DestroyWorldRequest(world_name=world_name))
    finally:
        # Cleanup must run even if connecting, joining or stepping fails;
        # otherwise the server started above is leaked.
        server.stop(None)
        pygame.quit()
Ejemplo n.º 3
0
    def test_create_join_world_with_extension(self):
        """Checks that extensions are exposed as attributes on the returned env."""

        class _ExampleExtension:

            def foo(self):
                return 'bar'

        replies = [
            dm_env_rpc_pb2.CreateWorldResponse(world_name='foo'),
            dm_env_rpc_pb2.JoinWorldResponse(specs=_SAMPLE_SPEC),
        ]
        mock_connection = mock.MagicMock()
        mock_connection.send = mock.MagicMock(side_effect=replies)
        extension_instance = _ExampleExtension()

        env, _ = dm_env_adaptor.create_and_join_world(
            mock_connection,
            create_world_settings={},
            join_world_settings={},
            extensions={'extension': extension_instance})

        # The extension registered under 'extension' is reachable by that name.
        self.assertEqual('bar', env.extension.foo())
Ejemplo n.º 4
0
    def test_created_but_failed_to_join_world(self):
        """Checks that a failed JoinWorld destroys the created world and re-raises."""
        world = 'Damogran_01'
        join_failure = error.DmEnvRpcError(
            status_pb2.Status(message='Failed to Join.'))
        mock_connection = mock.MagicMock()
        # An exception instance in side_effect is raised by that send call.
        mock_connection.send = mock.MagicMock(side_effect=(
            dm_env_rpc_pb2.CreateWorldResponse(world_name=world),
            join_failure,
            dm_env_rpc_pb2.DestroyWorldResponse(),
        ))

        with self.assertRaisesRegex(error.DmEnvRpcError, 'Failed to Join'):
            dm_env_adaptor.create_and_join_world(
                mock_connection,
                create_world_settings={},
                join_world_settings={})

        mock_connection.send.assert_has_calls([
            mock.call(dm_env_rpc_pb2.CreateWorldRequest()),
            mock.call(dm_env_rpc_pb2.JoinWorldRequest(world_name=world)),
            mock.call(dm_env_rpc_pb2.DestroyWorldRequest(world_name=world)),
        ])
Ejemplo n.º 5
0
    def test_create_join_world_with_packed_settings(self):
        """Checks that pre-packed and plain settings can be mixed.

        Settings already packed with tensor_utils.pack_tensor are sent
        unchanged. The plain list setting is expected to be packed
        automatically, including its shape.
        """
        replies = [
            dm_env_rpc_pb2.CreateWorldResponse(world_name='Magrathea_02'),
            dm_env_rpc_pb2.JoinWorldResponse(specs=_SAMPLE_SPEC),
        ]
        mock_connection = mock.MagicMock()
        mock_connection.send = mock.MagicMock(side_effect=replies)

        create_settings = {'planet': tensor_utils.pack_tensor('Magrathea')}
        join_settings = {
            'ship_type': tensor_utils.pack_tensor(2),
            'player': tensor_utils.pack_tensor('arthur'),
            'unpacked_setting': [1, 2, 3],
        }
        env, world_name = dm_env_adaptor.create_and_join_world(
            mock_connection,
            create_world_settings=create_settings,
            join_world_settings=join_settings)

        self.assertIsNotNone(env)
        self.assertEqual('Magrathea_02', world_name)

        expected_create = text_format.Parse(
            """
            settings: {
              key: 'planet', value: { strings: { array: 'Magrathea' } }
            }
            """, dm_env_rpc_pb2.CreateWorldRequest())
        expected_join = text_format.Parse(
            """
            world_name: 'Magrathea_02'
            settings: { key: 'ship_type', value: { int64s: { array: 2 } } }
            settings: {
              key: 'player', value: { strings: { array: 'arthur' } }
            }
            settings: {
              key: 'unpacked_setting'
              value: { int64s: { array: 1 array: 2 array: 3 } shape: 3 }
            }
            """, dm_env_rpc_pb2.JoinWorldRequest())
        mock_connection.send.assert_has_calls(
            [mock.call(expected_create), mock.call(expected_join)])
Ejemplo n.º 6
0
    def test_create_join_world_with_invalid_extension(self):
        """Checks that an extension name clashing with an adaptor attribute is rejected.

        Registering an extension under 'step' must raise ValueError. The test
        then checks that the world just created and joined is cleaned up by
        leaving it and destroying it.
        """
        connection = mock.MagicMock()
        connection.send = mock.MagicMock(side_effect=[
            dm_env_rpc_pb2.CreateWorldResponse(world_name='foo'),
            dm_env_rpc_pb2.JoinWorldResponse(specs=_SAMPLE_SPEC),
            dm_env_rpc_pb2.LeaveWorldResponse(),
            # Reply to the DestroyWorldRequest sent during cleanup; it must be
            # a response message, as in the other cleanup tests.
            dm_env_rpc_pb2.DestroyWorldResponse(),
        ])

        with self.assertRaisesRegex(ValueError,
                                    'DmEnvAdaptor already has attribute'):
            _ = dm_env_adaptor.create_and_join_world(
                connection,
                create_world_settings={},
                join_world_settings={},
                extensions={'step': object()})

        connection.send.assert_has_calls([
            mock.call(dm_env_rpc_pb2.CreateWorldRequest()),
            mock.call(dm_env_rpc_pb2.JoinWorldRequest(world_name='foo')),
            mock.call(dm_env_rpc_pb2.LeaveWorldRequest()),
            mock.call(dm_env_rpc_pb2.DestroyWorldRequest(world_name='foo'))
        ])
Ejemplo n.º 7
0
    def test_created_and_joined_but_adaptor_failed(self):
        """Checks cleanup when the adaptor cannot be built after a successful join.

        Requesting an unknown observation raises ValueError. The world that
        was created and joined must then be left and destroyed.
        """
        world = 'Damogran_01'
        replies = (
            dm_env_rpc_pb2.CreateWorldResponse(world_name=world),
            dm_env_rpc_pb2.JoinWorldResponse(specs=_SAMPLE_SPEC),
            dm_env_rpc_pb2.LeaveWorldResponse(),
            dm_env_rpc_pb2.DestroyWorldResponse(),
        )
        mock_connection = mock.MagicMock()
        mock_connection.send = mock.MagicMock(side_effect=replies)

        with self.assertRaisesRegex(ValueError, 'Unsupported observations'):
            dm_env_adaptor.create_and_join_world(
                mock_connection,
                create_world_settings={},
                join_world_settings={},
                requested_observations=['invalid_observation'])

        mock_connection.send.assert_has_calls([
            mock.call(dm_env_rpc_pb2.CreateWorldRequest()),
            mock.call(dm_env_rpc_pb2.JoinWorldRequest(world_name=world)),
            mock.call(dm_env_rpc_pb2.LeaveWorldRequest()),
            mock.call(dm_env_rpc_pb2.DestroyWorldRequest(world_name=world)),
        ])