def test_reset_seed_setting(self):
    """Checks that resetting with the creation seed replays the first step.

    A world is created with seed 1234 and joined, and one step is taken. The
    world is then reset with the identical seed setting. The step that follows
    the reset must produce a response equal to the very first one.
    """
    def seed_settings():
        # A fresh settings mapping per request, mirroring the original's two
        # independent `pack_tensor(1234)` calls.
        return {'seed': tensor_utils.pack_tensor(1234)}

    create_response = self._connection.send(
        dm_env_rpc_pb2.CreateWorldRequest(settings=seed_settings()))
    self._world_name = create_response.world_name
    self._connection.send(
        dm_env_rpc_pb2.JoinWorldRequest(world_name=self._world_name))

    first_step = self._connection.send(dm_env_rpc_pb2.StepRequest())
    self._connection.send(
        dm_env_rpc_pb2.ResetRequest(settings=seed_settings()))
    replayed_step = self._connection.send(dm_env_rpc_pb2.StepRequest())

    self.assertEqual(first_step, replayed_step)
def step(self, actions):
    """Implements dm_env.Environment.step.

    Packs `actions` with the action specs and sends a StepRequest asking for
    the requested observation UIDs. The StepResponse is then converted into a
    `dm_env.TimeStep`. The step type is derived from the previous and current
    environment states:

      * RUNNING -> RUNNING gives MID.
      * not RUNNING -> RUNNING gives FIRST.
      * RUNNING -> not RUNNING gives LAST.

    Reward and discount are computed through `self.reward` / `self.discount`.
    If they were not explicitly requested, their default keys are then removed
    from the returned observations.

    Raises:
      RuntimeError: if neither the previous nor the new state is RUNNING.
    """
    running = dm_env_rpc_pb2.EnvironmentStateType.RUNNING
    request = dm_env_rpc_pb2.StepRequest(
        requested_observations=self._requested_observation_uids,
        actions=self._action_specs.pack(actions))
    response = self._connection.send(request)
    observations = self._observation_specs.unpack(response.observations)

    now_running = response.state == running
    was_running = self._last_state == running
    if now_running:
        step_type = (dm_env.StepType.MID if was_running
                     else dm_env.StepType.FIRST)
    elif was_running:
        step_type = dm_env.StepType.LAST
    else:
        raise RuntimeError('Environment transitioned from {} to {}'.format(
            self._last_state, response.state))
    self._last_state = response.state

    # Both hooks receive identical keyword arguments; reward is evaluated
    # before discount, as in the original.
    hook_kwargs = dict(
        state=response.state, step_type=step_type, observations=observations)
    reward = self.reward(**hook_kwargs)
    discount = self.discount(**hook_kwargs)

    # Hide the default reward/discount observations unless the caller
    # explicitly asked for them.
    if not self._is_reward_requested:
        observations.pop(DEFAULT_REWARD_KEY, None)
    if not self._is_discount_requested:
        observations.pop(DEFAULT_DISCOUNT_KEY, None)

    return dm_env.TimeStep(step_type, reward, discount, observations)
def _can_send_message(connection):
    """Returns whether `connection` is healthy and able to process requests.

    Probes the server by sending an empty StepRequest.

    Args:
      connection: An object with a `send(request)` method that raises
        `error.DmEnvRpcError` for server-side errors and `grpc.RpcError` for
        transport failures.

    Returns:
      True if the server handled the probe, whether it replied successfully or
      with a dm_env_rpc error. False if the gRPC call itself failed.
    """
    try:
        # This should return a response with an error unless the server isn't
        # yet receiving requests.
        connection.send(dm_env_rpc_pb2.StepRequest())
    except error.DmEnvRpcError:
        # The server received and processed the request, then reported an
        # error, so it is able to handle messages.
        return True
    except grpc.RpcError:
        # Transport-level failure: the server is not reachable yet.
        return False
    # A successful reply also proves the server is processing requests.
    # Without this explicit return the function would fall through and
    # return None, which callers would treat as "unhealthy".
    return True
def test_no_nested_actions_step(self):
    """Flat, dotted action names are sent as-is when nested_tensors=False."""
    running_response = text_format.Parse(
        """state: RUNNING""", dm_env_rpc_pb2.StepResponse())
    connection = mock.MagicMock()
    connection.send = mock.MagicMock(return_value=running_response)
    env = dm_env_adaptor.DmEnvAdaptor(
        connection,
        specs=_SAMPLE_NESTED_SPECS,
        requested_observations=[],
        nested_tensors=False)

    timestep = env.step({'foo.bar': 123})

    self.assertEqual(dm_env.StepType.FIRST, timestep.step_type)
    expected_request = text_format.Parse(
        """actions: { key: 1, value: { int32s: { array: 123 } } }""",
        dm_env_rpc_pb2.StepRequest())
    connection.send.assert_called_once_with(expected_request)
def test_spec_generate_value_step(self):
    """Values generated from the action spec round-trip into the request."""
    self._connection.send = mock.MagicMock(return_value=_SAMPLE_STEP_RESPONSE)
    spec_by_name = self._env.action_spec()
    actions = {key: value_spec.generate_value()
               for key, value_spec in spec_by_name.items()}

    self._env.step(actions)

    expected_actions = {
        1: tensor_utils.pack_tensor(actions['foo']),
        2: tensor_utils.pack_tensor(actions['bar'], dtype=np.str_),
    }
    expected_request = dm_env_rpc_pb2.StepRequest(
        requested_observations=[1, 2], actions=expected_actions)
    self._connection.send.assert_called_once_with(expected_request)
def test_requested_observations(self):
    """Only the UIDs of the requested observations appear in the request."""
    # The adaptor is built before `send` is mocked, as in the original.
    filtered_env = dm_env_adaptor.DmEnvAdaptor(
        self._connection, _SAMPLE_SPEC, ['foo'])
    self._connection.send = mock.MagicMock(return_value=_SAMPLE_STEP_RESPONSE)

    filtered_env.step({'foo': 4, 'bar': 'hello'})

    packed_actions = {
        1: tensor_utils.pack_tensor(4, dtype=dm_env_rpc_pb2.UINT8),
        2: tensor_utils.pack_tensor('hello'),
    }
    self._connection.send.assert_called_once_with(
        dm_env_rpc_pb2.StepRequest(
            requested_observations=[1], actions=packed_actions))
def test_nested_observations_step(self):
    """Dotted observation names are unflattened into nested dictionaries.

    The mocked server returns observation UID 1 (`foo.bar` in
    `_SAMPLE_NESTED_SPECS`) with value 42. The adaptor must expose it as
    `{'foo': {'bar': 42}}` and request only that UID.
    """
    connection = mock.MagicMock()
    connection.send = mock.MagicMock(return_value=text_format.Parse(
        """state: RUNNING observations: { key: 1, value: { int32s: { array: 42 } } }""",
        dm_env_rpc_pb2.StepResponse()))
    expected = {'foo': {'bar': 42}}
    env = dm_env_adaptor.DmEnvAdaptor(
        connection,
        specs=_SAMPLE_NESTED_SPECS,
        requested_observations=['foo.bar'])

    timestep = env.step({})

    self.assertEqual(dm_env.StepType.FIRST, timestep.step_type)
    # assertSameElements on dicts only compares their keys, so also check the
    # inner key and the unpacked value explicitly.
    self.assertSameElements(expected, timestep.observation)
    self.assertSameElements(expected['foo'], timestep.observation['foo'])
    self.assertEqual(expected['foo']['bar'], timestep.observation['foo']['bar'])
    connection.send.assert_called_once_with(
        dm_env_rpc_pb2.StepRequest(requested_observations=[1]))
def test_cannot_step_when_not_joined(self):
    """Stepping before joining any world must raise a DmEnvRpcError."""
    request = dm_env_rpc_pb2.StepRequest()
    with self.assertRaises(error.DmEnvRpcError):
        self._connection.send(request)
def step(self, **kwargs):
    """Sends a StepRequest built from `kwargs` and returns the StepResponse.

    Args:
      **kwargs: Field values forwarded verbatim to
        `dm_env_rpc_pb2.StepRequest`.

    Returns:
      Whatever `self.connection.send` returns for the request.
    """
    request = dm_env_rpc_pb2.StepRequest(**kwargs)
    return self.connection.send(request)
"""Tests for dm_env_rpc/dm_env adaptor.""" from absl.testing import absltest from absl.testing import parameterized import dm_env from dm_env import specs import mock import numpy as np from dm_env_rpc.v1 import dm_env_adaptor from dm_env_rpc.v1 import dm_env_rpc_pb2 from dm_env_rpc.v1 import tensor_utils _SAMPLE_STEP_REQUEST = dm_env_rpc_pb2.StepRequest( requested_observations=[1, 2], actions={ 1: tensor_utils.pack_tensor(4, dtype=dm_env_rpc_pb2.UINT8), 2: tensor_utils.pack_tensor('hello') }) _SAMPLE_STEP_RESPONSE = dm_env_rpc_pb2.StepResponse( state=dm_env_rpc_pb2.EnvironmentStateType.RUNNING, observations={ 1: tensor_utils.pack_tensor(5, dtype=dm_env_rpc_pb2.UINT8), 2: tensor_utils.pack_tensor('goodbye') }) _TERMINATED_STEP_RESPONSE = dm_env_rpc_pb2.StepResponse( state=dm_env_rpc_pb2.EnvironmentStateType.TERMINATED, observations={ 1: tensor_utils.pack_tensor(5, dtype=dm_env_rpc_pb2.UINT8), 2: tensor_utils.pack_tensor('goodbye') }) _SAMPLE_SPEC = dm_env_rpc_pb2.ActionObservationSpecs(