def test_no_nested_actions_step(self):
        # Every request sent over the mocked connection is answered with a
        # RUNNING StepResponse.
        running_reply = text_format.Parse(
            """state: RUNNING""", dm_env_rpc_pb2.StepResponse())
        connection = mock.MagicMock()
        connection.send = mock.MagicMock(return_value=running_reply)

        env = dm_env_adaptor.DmEnvAdaptor(
            connection,
            specs=_SAMPLE_NESTED_SPECS,
            requested_observations=[],
            nested_tensors=False)

        # With nested_tensors=False the action is supplied under its flat
        # 'foo.bar' key rather than as a nested dict.
        step_result = env.step({'foo.bar': 123})
        self.assertEqual(dm_env.StepType.FIRST, step_result.step_type)

        expected_request = text_format.Parse(
            """actions: { key: 1, value: { int32s: { array: 123 } } }""",
            dm_env_rpc_pb2.StepRequest())
        connection.send.assert_called_once_with(expected_request)
    def test_nested_observations_step(self):
        """Requesting 'foo.bar' yields the observation nested as foo -> bar."""
        connection = mock.MagicMock()
        connection.send = mock.MagicMock(return_value=text_format.Parse(
            """state: RUNNING
        observations: { key: 1, value: { int32s: { array: 42 } } }""",
            dm_env_rpc_pb2.StepResponse()))

        expected = {'foo': {'bar': 42}}

        env = dm_env_adaptor.DmEnvAdaptor(connection,
                                          specs=_SAMPLE_NESTED_SPECS,
                                          requested_observations=['foo.bar'])
        timestep = env.step({})
        self.assertEqual(dm_env.StepType.FIRST, timestep.step_type)
        # assertSameElements iterates its arguments, which for dicts means it
        # only compares keys. Check every nesting level and the leaf value
        # explicitly so a wrong observation value cannot slip through.
        self.assertSameElements(expected, timestep.observation)
        self.assertSameElements(expected['foo'], timestep.observation['foo'])
        self.assertEqual(expected['foo']['bar'],
                         timestep.observation['foo']['bar'])

        connection.send.assert_called_once_with(
            dm_env_rpc_pb2.StepRequest(requested_observations=[1]))
Beispiel #3
0
    def Process(self, request_iterator, context):
        """Processes incoming EnvironmentRequests.

        For each EnvironmentRequest the internal message is extracted and
        handled. The response for that message is then placed in an
        EnvironmentResponse which is returned to the client.

        An error status is returned (rather than raised) if a request has no
        payload, if an unknown message type is received, or if the message is
        invalid for the current world state.

        Args:
          request_iterator: Message iterator provided by gRPC.
          context: Context provided by gRPC. Not used by this method.

        Yields:
          EnvironmentResponse: Response for each incoming EnvironmentRequest.
        """

        env_factory = CatchGameFactory(_INITIAL_SEED)
        env = None
        is_joined = False
        # Set whenever a new game is created so that the next step ignores the
        # actions sent with it and only reports the fresh game state.
        skip_next_frame = False
        action_manager = spec_manager.SpecManager(_action_spec())
        observation_manager = spec_manager.SpecManager(_observation_spec())

        def copy_specs(specs):
            """Copies the action and observation specs into the `specs` proto."""
            for uid, action in _action_spec().items():
                specs.actions[uid].CopyFrom(action)
            for uid, observation in _observation_spec().items():
                specs.observations[uid].CopyFrom(observation)

        for request in request_iterator:
            environment_response = dm_env_rpc_pb2.EnvironmentResponse()
            try:
                message_type = request.WhichOneof('payload')
                if message_type is None:
                    # Otherwise getattr(request, None) below raises an opaque
                    # "attribute name must be string" TypeError.
                    raise RuntimeError('EnvironmentRequest has no payload set.')
                internal_request = getattr(request, message_type)
                _check_message_type(env, is_joined, message_type)

                if message_type == 'create_world':
                    env = env_factory.new_game()
                    skip_next_frame = True
                    response = dm_env_rpc_pb2.CreateWorldResponse(
                        world_name=_WORLD_NAME)
                elif message_type == 'join_world':
                    if is_joined:
                        raise RuntimeError(
                            f'Tried to join world "{internal_request.world_name}" but '
                            f'already joined to world "{_WORLD_NAME}"')
                    if internal_request.world_name != _WORLD_NAME:
                        raise RuntimeError(
                            f'Tried to join world "{internal_request.world_name}" but the '
                            f'only supported world is "{_WORLD_NAME}"')
                    response = dm_env_rpc_pb2.JoinWorldResponse()
                    copy_specs(response.specs)
                    is_joined = True
                elif message_type == 'step':
                    # We need to skip all actions after creating or resetting the
                    # environment.
                    if skip_next_frame:
                        skip_next_frame = False
                    else:
                        unpacked_actions = action_manager.unpack(
                            internal_request.actions)
                        paddle_action = unpacked_actions.get(
                            _ACTION_PADDLE, _DEFAULT_ACTION)
                        env.update(paddle_action)

                    response = dm_env_rpc_pb2.StepResponse()
                    packed_observations = observation_manager.pack({
                        _OBSERVATION_BOARD: env.draw_board(),
                        _OBSERVATION_REWARD: env.reward(),
                    })

                    for requested_observation in internal_request.requested_observations:
                        try:
                            packed = packed_observations[requested_observation]
                        except KeyError as err:
                            # A bare KeyError would only report the uid itself.
                            raise RuntimeError(
                                f'Requested unknown observation uid '
                                f'{requested_observation}') from err
                        response.observations[requested_observation].CopyFrom(
                            packed)

                    terminated = env.has_terminated()
                    if terminated:
                        response.state = dm_env_rpc_pb2.EnvironmentStateType.TERMINATED
                        # Start a fresh game right away; its first step skips
                        # actions, exactly as after create/reset.
                        env = env_factory.new_game()
                        skip_next_frame = True
                    else:
                        response.state = dm_env_rpc_pb2.EnvironmentStateType.RUNNING
                elif message_type == 'reset':
                    env = env_factory.new_game()
                    skip_next_frame = True
                    response = dm_env_rpc_pb2.ResetResponse()
                    copy_specs(response.specs)
                elif message_type == 'reset_world':
                    env = env_factory.new_game()
                    skip_next_frame = True
                    response = dm_env_rpc_pb2.ResetWorldResponse()
                elif message_type == 'leave_world':
                    is_joined = False
                    response = dm_env_rpc_pb2.LeaveWorldResponse()
                elif message_type == 'destroy_world':
                    if internal_request.world_name != _WORLD_NAME:
                        raise RuntimeError(
                            f'Tried to destroy world "{internal_request.world_name}" '
                            f'but we only support world "{_WORLD_NAME}"')
                    env = None
                    response = dm_env_rpc_pb2.DestroyWorldResponse()
                else:
                    raise RuntimeError(f'Unhandled message: {message_type}')
                getattr(environment_response, message_type).CopyFrom(response)
            except Exception as e:  # pylint: disable=broad-except
                # Report any failure to the client as an error status instead
                # of tearing down the stream.
                environment_response.error.CopyFrom(
                    status_pb2.Status(message=str(e)))

            yield environment_response
Beispiel #4
0
import mock
import numpy as np

from dm_env_rpc.v1 import dm_env_adaptor
from dm_env_rpc.v1 import dm_env_rpc_pb2
from dm_env_rpc.v1 import tensor_utils

_SAMPLE_STEP_REQUEST = dm_env_rpc_pb2.StepRequest(
    requested_observations=[1, 2],
    actions={
        1: tensor_utils.pack_tensor(4, dtype=dm_env_rpc_pb2.UINT8),
        2: tensor_utils.pack_tensor('hello'),
    })


def _make_step_response(state):
    """Builds a StepResponse in `state` carrying the sample observations.

    Each call packs fresh observation tensors: uid 1 holds the UINT8 value 5
    and uid 2 holds the string 'goodbye'.
    """
    return dm_env_rpc_pb2.StepResponse(
        state=state,
        observations={
            1: tensor_utils.pack_tensor(5, dtype=dm_env_rpc_pb2.UINT8),
            2: tensor_utils.pack_tensor('goodbye'),
        })


_SAMPLE_STEP_RESPONSE = _make_step_response(
    dm_env_rpc_pb2.EnvironmentStateType.RUNNING)
_TERMINATED_STEP_RESPONSE = _make_step_response(
    dm_env_rpc_pb2.EnvironmentStateType.TERMINATED)
_SAMPLE_SPEC = dm_env_rpc_pb2.ActionObservationSpecs(
    actions={
        1: dm_env_rpc_pb2.TensorSpec(dtype=dm_env_rpc_pb2.UINT8, name='foo'),
        2: dm_env_rpc_pb2.TensorSpec(dtype=dm_env_rpc_pb2.STRING, name='bar')
    },
    observations={
        1: dm_env_rpc_pb2.TensorSpec(dtype=dm_env_rpc_pb2.UINT8, name='foo'),