Esempio n. 1
0
    def test_reset_seed_setting(self):
        """Checks that resetting with the creation seed repeats the first step.

        Creates and joins a world with seed 1234, records one step response,
        resets with the same seed, and expects the next step response to equal
        the recorded one.
        """
        connection = self._connection

        create_request = dm_env_rpc_pb2.CreateWorldRequest(
            settings={'seed': tensor_utils.pack_tensor(1234)})
        self._world_name = connection.send(create_request).world_name
        join_request = dm_env_rpc_pb2.JoinWorldRequest(
            world_name=self._world_name)
        connection.send(join_request)

        first_step = connection.send(dm_env_rpc_pb2.StepRequest())

        reset_request = dm_env_rpc_pb2.ResetRequest(
            settings={'seed': tensor_utils.pack_tensor(1234)})
        connection.send(reset_request)

        step_after_reset = connection.send(dm_env_rpc_pb2.StepRequest())
        self.assertEqual(first_step, step_after_reset)
Esempio n. 2
0
  def step(self, actions):
    """Implements dm_env.Environment.step.

    Packs `actions` into a `StepRequest`, sends it over the connection and
    converts the response into a `dm_env.TimeStep`.

    Args:
      actions: Actions for this step; packed via `self._action_specs.pack`.

    Returns:
      A `dm_env.TimeStep` holding the derived step type, the values returned
      by `self.reward` and `self.discount`, and the unpacked observations.

    Raises:
      RuntimeError: If neither the previous state nor the new state is
        RUNNING.
    """
    request = dm_env_rpc_pb2.StepRequest(
        requested_observations=self._requested_observation_uids,
        actions=self._action_specs.pack(actions))
    response = self._connection.send(request)
    observations = self._observation_specs.unpack(response.observations)

    # The step type depends only on whether the environment was, and now is,
    # in the RUNNING state.
    running = dm_env_rpc_pb2.EnvironmentStateType.RUNNING
    is_running = response.state == running
    was_running = self._last_state == running

    if not (is_running or was_running):
      raise RuntimeError('Environment transitioned from {} to {}'.format(
          self._last_state, response.state))
    if not was_running:
      step_type = dm_env.StepType.FIRST
    elif not is_running:
      step_type = dm_env.StepType.LAST
    else:
      step_type = dm_env.StepType.MID

    self._last_state = response.state

    # Reward and discount are computed before the observations are filtered,
    # so both hooks see the full set of unpacked observations.
    hook_kwargs = dict(
        state=response.state, step_type=step_type, observations=observations)
    reward = self.reward(**hook_kwargs)
    discount = self.discount(**hook_kwargs)

    # Drop the reward/discount entries unless they were explicitly requested.
    if not self._is_reward_requested:
      observations.pop(DEFAULT_REWARD_KEY, None)
    if not self._is_discount_requested:
      observations.pop(DEFAULT_DISCOUNT_KEY, None)

    return dm_env.TimeStep(step_type, reward, discount, observations)
def _can_send_message(connection):
    """Returns whether `connection` is healthy and able to process requests.

    Probes the connection by sending an empty `StepRequest`.

    Args:
      connection: Object exposing a `send(request)` method.

    Returns:
      True if the request was answered, either normally or with a
      `DmEnvRpcError`; False if sending raised `grpc.RpcError`.
    """
    try:
        # This should return a response with an error unless the server isn't yet
        # receiving requests.
        connection.send(dm_env_rpc_pb2.StepRequest())
    except error.DmEnvRpcError:
        # An error response still proves the server is processing requests.
        return True
    except grpc.RpcError:
        return False
    # The request was answered without an error, so the server is reachable.
    # Without this explicit return the function fell through and returned
    # None, which callers would treat as "not healthy".
    return True
    def test_no_nested_actions_step(self):
        """Steps a `nested_tensors=False` adaptor with a dotted action name.

        Expects the first timestep to be of type FIRST and the connection to
        receive exactly one request carrying action 1 with int32 value 123.
        """
        canned_response = text_format.Parse(
            """state: RUNNING""", dm_env_rpc_pb2.StepResponse())
        expected_request = text_format.Parse(
            """actions: { key: 1, value: { int32s: { array: 123 } } }""",
            dm_env_rpc_pb2.StepRequest())

        connection = mock.MagicMock()
        connection.send = mock.MagicMock(return_value=canned_response)
        adaptor = dm_env_adaptor.DmEnvAdaptor(
            connection,
            specs=_SAMPLE_NESTED_SPECS,
            requested_observations=[],
            nested_tensors=False)

        timestep = adaptor.step({'foo.bar': 123})

        self.assertEqual(dm_env.StepType.FIRST, timestep.step_type)
        connection.send.assert_called_once_with(expected_request)
Esempio n. 5
0
 def test_spec_generate_value_step(self):
     """Steps with values generated from the action spec.

     Each action spec produces a value through `generate_value()`; the test
     expects a single request asking for observations 1 and 2 whose actions
     are the packed 'foo' and 'bar' values.
     """
     send_mock = mock.MagicMock(return_value=_SAMPLE_STEP_RESPONSE)
     self._connection.send = send_mock

     generated = {}
     for action_name, spec in self._env.action_spec().items():
         generated[action_name] = spec.generate_value()

     self._env.step(generated)

     packed_actions = {
         1: tensor_utils.pack_tensor(generated['foo']),
         2: tensor_utils.pack_tensor(generated['bar'], dtype=np.str_),
     }
     expected_request = dm_env_rpc_pb2.StepRequest(
         requested_observations=[1, 2], actions=packed_actions)
     send_mock.assert_called_once_with(expected_request)
Esempio n. 6
0
  def test_requested_observations(self):
    """Checks that only the requested observation is asked for when stepping.

    An adaptor built with `requested_observations=['foo']` is stepped with both
    actions; the request sent is expected to list observation 1 only while
    still carrying both packed actions.
    """
    env = dm_env_adaptor.DmEnvAdaptor(self._connection, _SAMPLE_SPEC, ['foo'])

    # The send mock is installed only after the adaptor has been constructed.
    send_mock = mock.MagicMock(return_value=_SAMPLE_STEP_RESPONSE)
    self._connection.send = send_mock
    env.step({'foo': 4, 'bar': 'hello'})

    packed_actions = {
        1: tensor_utils.pack_tensor(4, dtype=dm_env_rpc_pb2.UINT8),
        2: tensor_utils.pack_tensor('hello'),
    }
    expected_request = dm_env_rpc_pb2.StepRequest(
        requested_observations=[1], actions=packed_actions)
    send_mock.assert_called_once_with(expected_request)
    def test_nested_observations_step(self):
        """Steps an adaptor that requests the dotted observation 'foo.bar'.

        Expects a FIRST timestep, an observation compared against
        `{'foo': {'bar': 42}}`, and a single request for observation 1.
        """
        canned_response = text_format.Parse(
            """state: RUNNING
        observations: { key: 1, value: { int32s: { array: 42 } } }""",
            dm_env_rpc_pb2.StepResponse())
        connection = mock.MagicMock()
        connection.send = mock.MagicMock(return_value=canned_response)

        adaptor = dm_env_adaptor.DmEnvAdaptor(
            connection,
            specs=_SAMPLE_NESTED_SPECS,
            requested_observations=['foo.bar'])
        timestep = adaptor.step({})

        self.assertEqual(dm_env.StepType.FIRST, timestep.step_type)
        # NOTE(review): assertSameElements iterates its arguments, so for dicts
        # this presumably compares keys only -- confirm whether the nested
        # values should be checked as well.
        self.assertSameElements({'foo': {'bar': 42}}, timestep.observation)

        expected_request = dm_env_rpc_pb2.StepRequest(
            requested_observations=[1])
        connection.send.assert_called_once_with(expected_request)
Esempio n. 8
0
 def test_cannot_step_when_not_joined(self):
     """Expects sending a `StepRequest` to raise `DmEnvRpcError`.

     Per the test name, the fixture's connection is assumed not to have
     joined a world -- the setup is not visible here.
     """
     with self.assertRaises(error.DmEnvRpcError):
         request = dm_env_rpc_pb2.StepRequest()
         self._connection.send(request)
Esempio n. 9
0
 def step(self, **kwargs):
     """Sends a StepRequest and returns the StepResponse.

     Args:
       **kwargs: Forwarded unchanged to the `StepRequest` constructor.

     Returns:
       Whatever `self.connection.send` returns for the request.
     """
     send = self.connection.send
     request = dm_env_rpc_pb2.StepRequest(**kwargs)
     return send(request)
Esempio n. 10
0
"""Tests for dm_env_rpc/dm_env adaptor."""

from absl.testing import absltest
from absl.testing import parameterized
import dm_env
from dm_env import specs
import mock
import numpy as np

from dm_env_rpc.v1 import dm_env_adaptor
from dm_env_rpc.v1 import dm_env_rpc_pb2
from dm_env_rpc.v1 import tensor_utils

# Step request fixture: asks for observations 1 and 2 and supplies action 1
# (the value 4 packed as UINT8) and action 2 (the string 'hello').
_SAMPLE_STEP_REQUEST = dm_env_rpc_pb2.StepRequest(
    requested_observations=[1, 2],
    actions={
        1: tensor_utils.pack_tensor(4, dtype=dm_env_rpc_pb2.UINT8),
        2: tensor_utils.pack_tensor('hello')
    })
# Step response fixture in the RUNNING state: observation 1 is the value 5
# packed as UINT8 and observation 2 is the string 'goodbye'.
_SAMPLE_STEP_RESPONSE = dm_env_rpc_pb2.StepResponse(
    state=dm_env_rpc_pb2.EnvironmentStateType.RUNNING,
    observations={
        1: tensor_utils.pack_tensor(5, dtype=dm_env_rpc_pb2.UINT8),
        2: tensor_utils.pack_tensor('goodbye')
    })
# Same observations as the RUNNING response above, but with the state set to
# TERMINATED.
_TERMINATED_STEP_RESPONSE = dm_env_rpc_pb2.StepResponse(
    state=dm_env_rpc_pb2.EnvironmentStateType.TERMINATED,
    observations={
        1: tensor_utils.pack_tensor(5, dtype=dm_env_rpc_pb2.UINT8),
        2: tensor_utils.pack_tensor('goodbye')
    })
_SAMPLE_SPEC = dm_env_rpc_pb2.ActionObservationSpecs(