コード例 #1
0
def _connect_to_environment(port, settings):
    """Connects to a running dm_hard_eight environment and joins a new world.

    Args:
      port: Port passed to `_create_channel_and_connection`.
      settings: Object exposing `level_name`, `seed`, `width` and `height`.

    Returns:
      A `_ConnectionDetails` holding the channel, the connection (whose `send`
      now routes every request through `_wrap_send`) and the specs from the
      JoinWorld response.

    Raises:
      ValueError: If `settings.level_name` is not in
        `HARD_EIGHT_TASK_LEVEL_NAMES`.
    """
    if settings.level_name not in HARD_EIGHT_TASK_LEVEL_NAMES:
        raise ValueError(
            'Level named "{}" is not supported for dm_hard_eight'.format(
                settings.level_name))

    channel, connection = _create_channel_and_connection(port)
    # Capture the unwrapped method before replacing it so the wrapper does not
    # end up calling itself.
    plain_send = connection.send
    connection.send = (
        lambda request: _wrap_send(lambda: plain_send(request)))

    create_request = dm_env_rpc_pb2.CreateWorldRequest(
        settings={
            'seed': tensor_utils.pack_tensor(settings.seed),
            'episodeId': tensor_utils.pack_tensor(0),
            'levelName': tensor_utils.pack_tensor(settings.level_name),
        })
    world_name = connection.send(create_request).world_name

    join_request = dm_env_rpc_pb2.JoinWorldRequest(
        world_name=world_name,
        settings={
            'width': tensor_utils.pack_tensor(settings.width),
            'height': tensor_utils.pack_tensor(settings.height),
        })
    specs = connection.send(join_request).specs

    return _ConnectionDetails(channel=channel, connection=connection,
                              specs=specs)
コード例 #2
0
ファイル: docker.py プロジェクト: frangipane/dm_construction
def _connect_to_environment(port, settings):
    """Connects to a running dm_construction environment and joins a new world.

    Every entry of `settings` is packed into a tensor. The packed
    `"levelName"` (together with a seed and episode id of 0) is used to create
    the world; all remaining packed entries become the join settings.

    Args:
      port: Port passed to `_create_channel_and_connection`.
      settings: Mapping of setting name to unpacked value. Must contain
        `"levelName"`.

    Returns:
      A `(channel, connection, specs)` tuple, where `connection.send` now
      routes every request through `_wrap_send`.

    Raises:
      KeyError: If `settings` has no `"levelName"` entry.
    """
    channel, connection = _create_channel_and_connection(port)

    unwrapped_send = connection.send

    def send_via_wrapper(request):
        return _wrap_send(lambda: unwrapped_send(request))

    connection.send = send_via_wrapper

    packed = {name: tensor_utils.pack_tensor(value)
              for name, value in settings.items()}

    level_tensor = packed["levelName"]
    world_name = connection.send(
        dm_env_rpc_pb2.CreateWorldRequest(settings={
            "levelName": level_tensor,
            "seed": tensor_utils.pack_tensor(0),
            "episodeId": tensor_utils.pack_tensor(0),
        })).world_name

    # Everything except the level name is forwarded to JoinWorld.
    join_settings = {name: tensor for name, tensor in packed.items()
                     if name != "levelName"}
    specs = connection.send(
        dm_env_rpc_pb2.JoinWorldRequest(world_name=world_name,
                                        settings=join_settings)).specs

    return channel, connection, specs
コード例 #3
0
 def test_unpack(self):
     """Unpacks a uid-keyed tensor map into name-keyed array values."""
     packed_by_uid = {
         54: tensor_utils.pack_tensor([1.0, 2.0], dtype=np.float32),
         55: tensor_utils.pack_tensor([3, 4, 5], dtype=np.int32),
     }
     result = self._spec_manager.unpack(packed_by_uid)

     self.assertLen(result, 2)
     expected = {'fuzz': np.asarray([1.0, 2.0]), 'foo': np.asarray([3, 4, 5])}
     for name, value in expected.items():
         np.testing.assert_array_equal(value, result[name])
コード例 #4
0
 def test_pack(self):
     """Packs name-keyed values into a uid-keyed map of tensors."""
     values_by_name = {'fuzz': [1.0, 2.0], 'foo': [3, 4, 5]}
     actual = self._spec_manager.pack(values_by_name)

     self.assertDictEqual(
         {
             54: tensor_utils.pack_tensor([1.0, 2.0], dtype=np.float32),
             55: tensor_utils.pack_tensor([3, 4, 5], dtype=np.int32),
         },
         actual)
コード例 #5
0
ファイル: catch_test.py プロジェクト: rsfb/dm_env_rpc
    def test_reset_seed_setting(self):
        """Resetting with the creation seed reproduces the same step response."""
        def seed_settings():
            # A fresh settings dict (and tensor) for every request.
            return {'seed': tensor_utils.pack_tensor(1234)}

        send = self._connection.send
        create_response = send(
            dm_env_rpc_pb2.CreateWorldRequest(settings=seed_settings()))
        self._world_name = create_response.world_name
        send(dm_env_rpc_pb2.JoinWorldRequest(world_name=self._world_name))

        first_step = send(dm_env_rpc_pb2.StepRequest())
        send(dm_env_rpc_pb2.ResetRequest(settings=seed_settings()))
        step_after_reset = send(dm_env_rpc_pb2.StepRequest())

        self.assertEqual(first_step, step_after_reset)
コード例 #6
0
ファイル: tensor_utils_test.py プロジェクト: rsfb/dm_env_rpc
    def test_unpack_scalar_proto(self):
        """A packed scalar proto unpacks back into an equal message."""
        scalar = struct_pb2.Value(string_value='my message')
        tensor = tensor_utils.pack_tensor(scalar)

        unpacked = struct_pb2.Value()
        # Any.Unpack signals a type mismatch by returning False instead of
        # raising, so assert on it to get a precise failure.
        self.assertTrue(tensor_utils.unpack_tensor(tensor).Unpack(unpacked))
        self.assertEqual(scalar, unpacked)
コード例 #7
0
 def test_spec_generate_value_step(self):
     """Values generated from the action specs are packed into a StepRequest."""
     self._connection.send = mock.MagicMock(
         return_value=_SAMPLE_STEP_RESPONSE)

     actions = {}
     for name, spec in self._env.action_spec().items():
         actions[name] = spec.generate_value()

     self._env.step(actions)

     expected_request = dm_env_rpc_pb2.StepRequest(
         requested_observations=[1, 2],
         actions={
             1: tensor_utils.pack_tensor(actions['foo']),
             2: tensor_utils.pack_tensor(actions['bar'], dtype=np.str_),
         })
     self._connection.send.assert_called_once_with(expected_request)
コード例 #8
0
ファイル: tensor_utils_test.py プロジェクト: rsfb/dm_env_rpc
 def test_pack_scalar_protos(self):
     """A scalar proto packs into a shapeless tensor holding one payload."""
     message = struct_pb2.Value(string_value='my message')
     packed = tensor_utils.pack_tensor(message)

     self.assertEqual([], packed.shape)
     payloads = packed.protos.array
     self.assertLen(payloads, 1)

     restored = struct_pb2.Value()
     self.assertTrue(payloads[0].Unpack(restored))
     self.assertEqual(message, restored)
コード例 #9
0
  def test_requested_observations(self):
    """StepRequests only ask for the observations the adaptor was given."""
    filtered_env = dm_env_adaptor.DmEnvAdaptor(
        self._connection, _SAMPLE_SPEC, ['foo'])
    self._connection.send = mock.MagicMock(return_value=_SAMPLE_STEP_RESPONSE)

    filtered_env.step({'foo': 4, 'bar': 'hello'})

    self._connection.send.assert_called_once_with(
        dm_env_rpc_pb2.StepRequest(
            requested_observations=[1],
            actions={
                1: tensor_utils.pack_tensor(4, dtype=dm_env_rpc_pb2.UINT8),
                2: tensor_utils.pack_tensor('hello'),
            }))
コード例 #10
0
def _create_test_tensor(spec, dtype=None):
    """Builds a tensor whose shape and dtype are consistent with `spec`.

    Negative entries in `spec.shape` (presumably variable-size dimensions) are
    replaced by 1, and every element is the value from `_create_test_value`.

    Args:
      spec: TensorSpec the tensor must conform to.
      dtype: Optional dtype override; `spec.dtype` is used when this is falsy.

    Returns:
      The packed tensor, with its `shape` field set explicitly.
    """
    dims = np.asarray(spec.shape)
    dims[dims < 0] = 1
    fill_value = _create_test_value(spec)
    element_count = int(np.prod(dims))
    tensor = tensor_utils.pack_tensor([fill_value] * element_count,
                                      dtype=dtype or spec.dtype)
    tensor.shape[:] = dims
    return tensor
コード例 #11
0
def _connect_to_environment(port, settings):
    """Connects to a running dm_fast_mapping environment and joins a world.

    Args:
      port: Port passed to `_create_channel_and_connection`.
      settings: Object exposing `level_name`, `seed`, `width`, `height` and
        `episode_length_seconds`.

    Returns:
      A `_ConnectionDetails` with the channel, the connection (whose `send`
      now goes through `_wrap_send`) and the specs from the JoinWorld
      response.

    Raises:
      ValueError: If `settings.level_name` is not in
        `FAST_MAPPING_TASK_LEVEL_NAMES`.
    """
    if settings.level_name not in FAST_MAPPING_TASK_LEVEL_NAMES:
        raise ValueError(
            'Level named "{}" is not a valid dm_fast_mapping level.'.format(
                settings.level_name))

    channel, connection = _create_channel_and_connection(port)
    send_unwrapped = connection.send
    connection.send = (
        lambda request: _wrap_send(lambda: send_unwrapped(request)))

    def packed(**values):
        # Keyword order is preserved, so keys appear in the order given.
        return {key: tensor_utils.pack_tensor(val)
                for key, val in values.items()}

    create_settings = packed(
        seed=settings.seed,
        episodeId=0,
        levelName=settings.level_name)
    world_name = connection.send(
        dm_env_rpc_pb2.CreateWorldRequest(settings=create_settings)).world_name

    join_settings = packed(
        width=settings.width,
        height=settings.height,
        EpisodeLengthSeconds=settings.episode_length_seconds,
        ShowReachabilityHUD=False)
    specs = connection.send(
        dm_env_rpc_pb2.JoinWorldRequest(
            world_name=world_name, settings=join_settings)).specs

    return _ConnectionDetails(channel=channel, connection=connection,
                              specs=specs)
コード例 #12
0
 def test_cannot_send_wrong_type_to_nonnumeric_actions(self):
     """Sending an int32 tensor to any non-numeric action raises."""
     int_tensor = tensor_utils.pack_tensor(0, dtype=np.int32)
     for uid, spec in self.nonnumeric_actions.items():
         with self.subTest(uid=uid, name=spec.name):
             declared = np.asarray(spec.shape)
             # The same one-element tensor is reused; only its declared shape
             # is rewritten for each action (negative dims become 1).
             int_tensor.shape[:] = np.where(declared < 0, 1, declared)
             with self.assertRaises(error.DmEnvRpcError):
                 self.step(actions={uid: int_tensor})
コード例 #13
0
 def test_can_send_broadcastable_actions(self):
     """A test value reshaped to each action's spec shape is accepted."""
     for uid, spec in self.specs.actions.items():
         with self.subTest(uid=uid, name=spec.name):
             action = tensor_utils.pack_tensor(_create_test_value(spec),
                                               dtype=spec.dtype)
             declared = np.asarray(spec.shape)
             # Negative (unspecified) dimensions are given size 1.
             action.shape[:] = np.where(declared < 0, 1, declared)
             self.step(actions={uid: action})
コード例 #14
0
ファイル: tensor_utils_test.py プロジェクト: rsfb/dm_env_rpc
 def test_pack_proto_arrays(self):
     """An array of protos packs into a rank-1 tensor of Any payloads."""
     array = np.array([
         struct_pb2.Value(string_value=message)
         for message in ['foo', 'bar']
     ])
     tensor = tensor_utils.pack_tensor(array)
     self.assertEqual([2], tensor.shape)

     payloads = tensor.protos.array
     # Guard against extra or missing payloads before comparing pairwise.
     self.assertLen(payloads, 2)
     for expected, payload in zip(array, payloads):
         # Use a fresh message per element, and check Unpack's result: it
         # reports a type mismatch by returning False rather than raising.
         unpacked = struct_pb2.Value()
         self.assertTrue(payload.Unpack(unpacked))
         self.assertEqual(expected, unpacked)
コード例 #15
0
def _connect_to_environment(port, settings):
    """Connects to a running Alchemy environment and joins a new world.

    Args:
      port: Port passed to `_create_channel_and_connection`.
      settings: Object exposing `level_name`, `seed`, `width` and `height`.

    Returns:
      A `_ConnectionDetails` with the channel, the connection (whose `send`
      now routes through `_wrap_send`) and the specs from the JoinWorld
      response.

    Raises:
      ValueError: If `settings.level_name` is not in `ALCHEMY_LEVEL_NAMES`.
    """
    if settings.level_name not in ALCHEMY_LEVEL_NAMES:
        raise ValueError(
            'Level named "{}" is not a valid dm_alchemy level.'.format(
                settings.level_name))

    channel, connection = _create_channel_and_connection(port)
    direct_send = connection.send

    def send_through_wrapper(request):
        return _wrap_send(lambda: direct_send(request))

    connection.send = send_through_wrapper

    pack = tensor_utils.pack_tensor
    create_request = dm_env_rpc_pb2.CreateWorldRequest(settings={
        'seed': pack(settings.seed),
        'episodeId': pack(0),
        'levelName': pack(settings.level_name),
        'EventSubscriptionRegex': pack('DeepMind/.*'),
    })
    world_name = connection.send(create_request).world_name

    join_request = dm_env_rpc_pb2.JoinWorldRequest(
        world_name=world_name,
        settings={
            'width': pack(settings.width),
            'height': pack(settings.height),
        })
    join_response = connection.send(join_request)

    return _ConnectionDetails(channel=channel, connection=connection,
                              specs=join_response.specs)
コード例 #16
0
 def test_cannot_send_action_above_max(self):
     """Each numeric action rejects a value above its max, if one exists."""
     for uid, spec in self.numeric_actions.items():
         with self.subTest(uid=uid, name=spec.name):
             too_large = _above_max(spec)
             # None means the spec's dtype has no value above its max.
             if too_large is not None:
                 action = tensor_utils.pack_tensor(too_large, dtype=spec.dtype)
                 dims = np.asarray(spec.shape)
                 dims[dims < 0] = 1
                 action.shape[:] = dims
                 with self.assertRaises(error.DmEnvRpcError):
                     self.step(actions={uid: action})
コード例 #17
0
ファイル: tensor_utils_test.py プロジェクト: rsfb/dm_env_rpc
    def test_unpack_proto_arrays(self):
        """Unpacking a packed array of protos yields the original messages."""
        array = np.array([
            struct_pb2.Value(string_value=message)
            for message in ['foo', 'bar']
        ])
        tensor = tensor_utils.pack_tensor(array)
        round_trip = tensor_utils.unpack_tensor(tensor)

        # Guard against extra or missing elements before comparing pairwise.
        self.assertLen(round_trip, len(array))
        for expected, packed_message in zip(array, round_trip):
            # Use a fresh message per element, and check Unpack's result: it
            # reports a type mismatch by returning False rather than raising.
            unpacked = struct_pb2.Value()
            self.assertTrue(packed_message.Unpack(unpacked))
            self.assertEqual(expected, unpacked)
コード例 #18
0
    def test_create_join_world_with_packed_settings(self):
        """create_and_join_world accepts a mix of packed and unpacked settings.

        Already-packed tensors are expected to be forwarded unchanged and the
        plain list to be packed, as spelled out by the expected CreateWorld
        and JoinWorld requests below.
        """
        connection = mock.MagicMock()
        # Canned responses returned by successive send() calls: CreateWorld
        # first, then JoinWorld.
        connection.send = mock.MagicMock(side_effect=[
            dm_env_rpc_pb2.CreateWorldResponse(world_name='Magrathea_02'),
            dm_env_rpc_pb2.JoinWorldResponse(specs=_SAMPLE_SPEC)
        ])
        env, world_name = dm_env_adaptor.create_and_join_world(
            connection,
            create_world_settings={
                'planet': tensor_utils.pack_tensor('Magrathea')
            },
            join_world_settings={
                'ship_type': tensor_utils.pack_tensor(2),
                'player': tensor_utils.pack_tensor('arthur'),
                # Deliberately left unpacked; the expected JoinWorldRequest
                # below shows it as an int64 tensor of shape [3].
                'unpacked_setting': [1, 2, 3],
            })
        self.assertIsNotNone(env)
        self.assertEqual('Magrathea_02', world_name)

        # The exact requests the helper is expected to have sent, in order.
        connection.send.assert_has_calls([
            mock.call(
                text_format.Parse(
                    """settings: {
                key: 'planet', value: { strings: { array: 'Magrathea' } }
            }""", dm_env_rpc_pb2.CreateWorldRequest())),
            mock.call(
                text_format.Parse(
                    """world_name: 'Magrathea_02'
                settings: { key: 'ship_type', value: { int64s: { array: 2 } } }
                settings: {
                    key: 'player', value: { strings: { array: 'arthur' } }
                }
                settings: {
                    key: 'unpacked_setting', value: {
                      int64s: { array: 1 array: 2 array: 3 }
                      shape: 3
                    }
                }""", dm_env_rpc_pb2.JoinWorldRequest())),
        ])
コード例 #19
0
ファイル: properties.py プロジェクト: rsfb/dm_env_rpc
    def write(self, key: str, value) -> None:
        """Writes `value` to the property named `key`.

        The value is packed with `tensor_utils.pack_tensor`, wrapped in a
        `WritePropertyRequest` inside a `PropertyRequest`, packed into an
        `Any`, and sent over the connection.

        Args:
          key: A string key that represents the property to write.
          value: A scalar (float, int, string, etc.), NumPy array, or nested
            lists. See `tensor_utils.pack_tensor` for the accepted inputs.
        """
        write_request = properties_pb2.WritePropertyRequest(
            key=key, value=tensor_utils.pack_tensor(value))
        property_request = properties_pb2.PropertyRequest(
            write_property=write_request)

        envelope = any_pb2.Any()
        envelope.Pack(property_request)
        self._connection.send(envelope)
コード例 #20
0
ファイル: spec_manager.py プロジェクト: rsfb/dm_env_rpc
  def pack(self, tensors):
    """Converts a name-keyed dict into a dict of tensors keyed by UID.

    Each value is packed with the dtype of its spec, and its shape is checked
    against that spec before it is stored.

    Args:
      tensors: A dict mapping string names to scalars and arrays.

    Returns:
      A dict mapping UIDs to dm_env_rpc tensor protos.
    """
    def pack_one(name, value):
      spec = self.name_to_spec(name)
      tensor = tensor_utils.pack_tensor(value, dtype=spec.dtype)
      _assert_shapes_match(tensor, spec)
      return tensor

    by_uid = {}
    for name, value in tensors.items():
      # The right-hand side runs before the key lookup, so the spec is
      # resolved and checked before the UID is requested.
      by_uid[self.name_to_uid(name)] = pack_one(name, value)
    return by_uid
コード例 #21
0
def create_and_join_world(
    connection: dm_env_rpc_connection.Connection,
    create_world_settings: Mapping[str, Any],
    join_world_settings: Mapping[str, Any],
    requested_observations: Optional[Iterable[str]] = None,
    extensions: Optional[Mapping[str, Any]] = frozendict.frozendict()
) -> Tuple[DmEnvAdaptor, str]:
    """Helper function to create and join a world with the provided settings.

    If joining fails with a `DmEnvRpcError` or `ValueError`, a
    DestroyWorldRequest for the newly created world is sent before the
    exception is re-raised.

    Args:
      connection: An instance of Connection already connected to a dm_env_rpc
        server.
      create_world_settings: Settings used to create the world. Values must be
        packable into a Tensor proto or already packed.
      join_world_settings: Settings used to join the world. Values must be
        packable into a Tensor message or already packed.
      requested_observations: Optional set of requested observations.
      extensions: Optional mapping of extension instances to DmEnvAdaptor
        attributes.

    Returns:
      A `DmEnvAndWorldName` named tuple `(env, world_name)` holding the
      DmEnvAdaptor and the created world name.

    Raises:
      error.DmEnvRpcError, ValueError: Re-raised from joining the world, after
        the created world has been destroyed.
    """
    # Packing the settings and sending the CreateWorldRequest is exactly what
    # create_world does; reuse it rather than duplicating the logic.
    world_name = create_world(connection, create_world_settings)

    # Built outside the try block so that only join failures trigger cleanup.
    return_type = collections.namedtuple('DmEnvAndWorldName',
                                         ['env', 'world_name'])
    try:
        env = join_world(connection, world_name, join_world_settings,
                         requested_observations, extensions)
    except (error.DmEnvRpcError, ValueError):
        # Avoid leaving the newly created world behind when joining fails.
        connection.send(
            dm_env_rpc_pb2.DestroyWorldRequest(world_name=world_name))
        raise
    return return_type(env, world_name)
コード例 #22
0
def join_world(
    connection: dm_env_rpc_connection.Connection,
    world_name: str,
    join_world_settings: Mapping[str, Any],
    requested_observations: Optional[Iterable[str]] = None,
    extensions: Optional[Mapping[str, Any]] = immutabledict.immutabledict()
) -> DmEnvAdaptor:
    """Joins an existing world and wraps the connection in a DmEnvAdaptor.

    Setting values that are not already `dm_env_rpc_pb2.Tensor` protos are
    packed with `tensor_utils.pack_tensor` before the JoinWorldRequest is
    sent.

    Args:
      connection: An instance of Connection already connected to a dm_env_rpc
        server.
      world_name: Name of the world to join.
      join_world_settings: Settings used to join the world. Values must be
        packable into a Tensor message or already packed.
      requested_observations: Optional set of requested observations.
      extensions: Optional mapping of extension instances to DmEnvAdaptor
        attributes.

    Returns:
      Instance of DmEnvAdaptor.

    Raises:
      ValueError: Propagated from the DmEnvAdaptor constructor, after a
        LeaveWorldRequest has been sent.
    """
    packed_settings = {}
    for key, value in join_world_settings.items():
        if not isinstance(value, dm_env_rpc_pb2.Tensor):
            value = tensor_utils.pack_tensor(value)
        packed_settings[key] = value

    join_request = dm_env_rpc_pb2.JoinWorldRequest(
        world_name=world_name, settings=packed_settings)
    specs = connection.send(join_request).specs

    try:
        adaptor = DmEnvAdaptor(connection, specs, requested_observations,
                               extensions=extensions)
    except ValueError:
        # Leave the world again so the connection is not left joined.
        connection.send(dm_env_rpc_pb2.LeaveWorldRequest())
        raise
    return adaptor
コード例 #23
0
def create_world(connection: dm_env_rpc_connection.Connection,
                 create_world_settings: Mapping[str, Any]) -> str:
    """Creates a world with the provided settings and returns its name.

    Args:
      connection: An instance of Connection already connected to a dm_env_rpc
        server.
      create_world_settings: Settings used to create the world. Values must be
        packable into a Tensor proto or already packed.

    Returns:
      Created world name.
    """
    def as_tensor(value):
        # Pre-packed tensors are forwarded untouched.
        if isinstance(value, dm_env_rpc_pb2.Tensor):
            return value
        return tensor_utils.pack_tensor(value)

    request = dm_env_rpc_pb2.CreateWorldRequest(settings={
        key: as_tensor(value) for key, value in create_world_settings.items()
    })
    response = connection.send(request)
    return response.world_name
コード例 #24
0
def _make_environment_connection(channel, connection, settings):
  """Creates and joins a dm_memorytask world over an existing connection.

  Args:
    channel: Channel object; stored unchanged in the result.
    connection: Connection whose `send` is replaced by one that routes every
      request through `_wrap_send`.
    settings: Object exposing `seed`, `level_name`, `width`, `height` and
      `episode_length_seconds`.

  Returns:
    A `_ConnectionDetails` with the channel, the wrapped connection and the
    specs from the JoinWorld response.
  """
  inner_send = connection.send

  def routed_send(request):
    return _wrap_send(lambda: inner_send(request))

  connection.send = routed_send

  create_settings = {
      'seed': tensor_utils.pack_tensor(settings.seed),
      'episodeId': tensor_utils.pack_tensor(0),
      'levelName': tensor_utils.pack_tensor(settings.level_name),
  }
  world_name = connection.send(
      dm_env_rpc_pb2.CreateWorldRequest(settings=create_settings)).world_name

  join_settings = {
      'width': tensor_utils.pack_tensor(settings.width),
      'height': tensor_utils.pack_tensor(settings.height),
      'EpisodeLengthSeconds':
          tensor_utils.pack_tensor(settings.episode_length_seconds),
  }
  join_response = connection.send(
      dm_env_rpc_pb2.JoinWorldRequest(world_name=world_name,
                                      settings=join_settings))
  return _ConnectionDetails(
      channel=channel, connection=connection, specs=join_response.specs)
コード例 #25
0
 def statement(self):
     """Statement under measurement: packs the unpacked array once."""
     # Change one element on every call so the packed result can never be
     # served from a cache of an earlier call.
     self._unpacked.flat[0] += 1
     tensor_utils.pack_tensor(self._unpacked, dtype=self._dtype)
コード例 #26
0
 def test_cannot_send_invalid_action_uid(self):
     """Stepping with a uid not among `self.action_uids` raises."""
     unknown_uid = _find_uid_not_in_set(self.action_uids)
     actions = {unknown_uid: tensor_utils.pack_tensor(0)}
     with self.assertRaises(error.DmEnvRpcError):
         self.step(actions=actions)
コード例 #27
0
 def setup(self):
     """Pre-packs an arange-filled array of the configured shape and dtype."""
     # Avoid an all-zeros array in case zeros are special-cased when packing.
     element_count = np.prod(self._shape)
     values = np.arange(element_count, dtype=self._dtype)
     self._packed = tensor_utils.pack_tensor(
         values.reshape(self._shape), dtype=self._dtype)
コード例 #28
0
 def test_first_step_actions_are_ignored(self):
     """Stepping with an unknown action uid here is expected not to raise."""
     unknown_uid = _find_uid_not_in_set(self.action_uids)
     bogus_actions = {unknown_uid: tensor_utils.pack_tensor(0)}
     self.step(actions=bogus_actions)
コード例 #29
0
import contextlib

from absl.testing import absltest
import grpc
import mock

from google.protobuf import any_pb2
from google.protobuf import struct_pb2
from google.rpc import status_pb2
from dm_env_rpc.v1 import connection as dm_env_rpc_connection
from dm_env_rpc.v1 import dm_env_rpc_pb2
from dm_env_rpc.v1 import error
from dm_env_rpc.v1 import tensor_utils

# A well-formed CreateWorld request/response pair.
_CREATE_REQUEST = dm_env_rpc_pb2.CreateWorldRequest(
    settings={'foo': tensor_utils.pack_tensor('bar')})
_CREATE_RESPONSE = dm_env_rpc_pb2.CreateWorldResponse()

# An empty CreateWorld request, and a response carrying an error status.
_BAD_CREATE_REQUEST = dm_env_rpc_pb2.CreateWorldRequest()
_TEST_ERROR = dm_env_rpc_pb2.EnvironmentResponse(
    error=status_pb2.Status(message='A test error.'))

# A DestroyWorld request paired with a LeaveWorld response, i.e. a response of
# a type that does not match the request.
_INCORRECT_RESPONSE_TEST_MSG = dm_env_rpc_pb2.DestroyWorldRequest(
    world_name='foo')
_INCORRECT_RESPONSE = dm_env_rpc_pb2.EnvironmentResponse(
    leave_world=dm_env_rpc_pb2.LeaveWorldResponse())

# Arbitrary struct values standing in for an extension request/response.
_EXTENSION_REQUEST = struct_pb2.Value(string_value='extension request')
_EXTENSION_RESPONSE = struct_pb2.Value(number_value=555)

コード例 #30
0
ファイル: docker.py プロジェクト: frangipane/dm_construction
 def write_property(self, name, value):
     """Write a property of the Unity environment.

     Args:
       name: Name of the property to write.
       value: Value packed with `tensor_utils.pack_tensor` before sending.
     """
     request = dm_env_rpc_pb2.WritePropertyRequest(
         properties={name: tensor_utils.pack_tensor(value)})
     self._connection.send(request)