コード例 #1
0
ファイル: tensor_utils_test.py プロジェクト: rsfb/dm_env_rpc
    def test_unpack_scalar_proto(self):
        """A packed scalar proto message unpacks back into an equal message."""
        original = struct_pb2.Value(string_value='my message')
        packed = tensor_utils.pack_tensor(original)

        restored = struct_pb2.Value()
        any_message = tensor_utils.unpack_tensor(packed)
        any_message.Unpack(restored)
        self.assertEqual(original, restored)
コード例 #2
0
 def test_negative_dimension_in_matrix(self):
     """A single -1 dimension is inferred from the number of elements."""
     proto = dm_env_rpc_pb2.Tensor()
     proto.int32s.array[:] = [1, 2, 3, 4, 5, 6]
     proto.shape[:] = [2, -1]
     actual = tensor_utils.unpack_tensor(proto)
     np.testing.assert_array_equal(
         np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32), actual)
コード例 #3
0
ファイル: spec_manager.py プロジェクト: deepmind/dm_env_rpc
    def unpack(
        self, dm_env_rpc_tensors: Mapping[int, dm_env_rpc_pb2.Tensor]
    ) -> MutableMapping[str, Any]:
        """Unpacks a dm_env_rpc uid-to-tensor map to a name-keyed Python dict.

        Each tensor is looked up by UID, checked with `_assert_shapes_match`
        against the spec registered under its name, and its dtype is compared
        with the spec's dtype before the payload is unpacked.

        Args:
          dm_env_rpc_tensors: A dict mapping UIDs to dm_env_rpc tensor protos.

        Returns:
          A dict mapping names to scalars and arrays.

        Raises:
          ValueError: if a tensor's dtype differs from its spec's dtype.
        """
        result = {}
        for uid, tensor in dm_env_rpc_tensors.items():
            name = self._uid_to_name[uid]
            spec = self.name_to_spec(name)
            _assert_shapes_match(tensor, spec)

            received_dtype = tensor_utils.get_tensor_type(tensor)
            expected_dtype = tensor_utils.data_type_to_np_type(spec.dtype)
            if received_dtype != expected_dtype:
                raise ValueError(
                    'Received dm_env_rpc tensor {} with dtype {} but spec has dtype {}.'
                    .format(name, received_dtype, expected_dtype))

            result[name] = tensor_utils.unpack_tensor(tensor)
        return result
コード例 #4
0
 def test_broadcasts_to_multidimensional_arrays(self):
     """A one-element payload fills every cell of a 2x2 shape."""
     proto = dm_env_rpc_pb2.Tensor()
     proto.int32s.array[:] = [4]
     proto.shape[:] = [2, 2]
     actual = tensor_utils.unpack_tensor(proto)
     np.testing.assert_array_equal(np.full((2, 2), 4, dtype=np.int32), actual)
コード例 #5
0
 def test_negative_dimension_single_element(self):
     """A lone -1 dimension resolves to the single available element."""
     proto = dm_env_rpc_pb2.Tensor()
     proto.int32s.array[:] = [1]
     proto.shape[:] = [-1]
     actual = tensor_utils.unpack_tensor(proto)
     np.testing.assert_array_equal(np.ones(1, dtype=np.int32), actual)
コード例 #6
0
 def test_unsigned_integer_broadcasts_1_element_to_all_elements(self):
     """A one-byte uint8 payload is repeated across a length-4 shape."""
     proto = dm_env_rpc_pb2.Tensor()
     proto.uint8s.array = b'\x01'
     proto.shape[:] = [4]
     actual = tensor_utils.unpack_tensor(proto)
     np.testing.assert_array_equal(np.ones(4, dtype=np.uint8), actual)
コード例 #7
0
 def test_string_broadcasts_1_element_to_all_elements(self):
     """A single string element is repeated to fill a length-4 shape."""
     proto = dm_env_rpc_pb2.Tensor()
     proto.strings.array[:] = ['foo']
     proto.shape[:] = [4]
     actual = tensor_utils.unpack_tensor(proto)
     np.testing.assert_array_equal(np.array(['foo'] * 4, dtype=np.str_), actual)
コード例 #8
0
 def test_integer_broadcasts_1_element_to_all_elements(self):
     """A one-element int32 payload is repeated across a length-4 shape."""
     proto = dm_env_rpc_pb2.Tensor()
     proto.int32s.array[:] = [1]
     proto.shape[:] = [4]
     actual = tensor_utils.unpack_tensor(proto)
     np.testing.assert_array_equal(np.array([1] * 4, dtype=np.int32), actual)
コード例 #9
0
 def test_unpack_multidimensional_arrays(self):
     """A flat float payload is laid out row by row into a 2x4 array."""
     proto = dm_env_rpc_pb2.Tensor()
     proto.floats.array[:] = [1, 2, 3, 4, 5, 6, 7, 8]
     proto.shape[:] = [2, 4]
     # Nothing is packed here, so this is a plain unpack, not a round trip.
     unpacked = tensor_utils.unpack_tensor(proto)
     expected = np.arange(1, 9).reshape(2, 4)
     np.testing.assert_array_equal(expected, unpacked)
コード例 #10
0
 def test_all_numerical_observations_in_range(self):
     """Every numeric observation returned by a step lies within its bounds."""
     numeric_uids = (
         uid
         for uid, spec in self.specs.observations.items()
         if _is_numeric_type(spec.dtype)
     )
     step_response = self.step(requested_observations=numeric_uids)
     for uid, tensor in step_response.observations.items():
         obs_spec = self.specs.observations[uid]
         with self.subTest(uid=uid, name=obs_spec.name):
             values = tensor_utils.unpack_tensor(tensor)
             limits = tensor_spec_utils.bounds(obs_spec)
             _assert_less_equal(values, limits.max)
             _assert_greater_equal(values, limits.min)
コード例 #11
0
ファイル: tensor_utils_test.py プロジェクト: rsfb/dm_env_rpc
    def test_unpack_proto_arrays(self):
        """An array of proto messages survives a pack/unpack round trip."""
        array = np.array([
            struct_pb2.Value(string_value=text) for text in ['foo', 'bar']
        ])
        round_trip = tensor_utils.unpack_tensor(tensor_utils.pack_tensor(array))

        # One scratch message is reused; each Unpack overwrites its contents.
        unpacked = struct_pb2.Value()
        for index in (0, 1):
            round_trip[index].Unpack(unpacked)
            self.assertEqual(array[index], unpacked)
コード例 #12
0
ファイル: properties.py プロジェクト: rsfb/dm_env_rpc
    def read(self, key: str):
        """Reads the value of a property.

        Args:
          key: A string key that represents the property to read.

        Returns:
          The value of the property, either as a scalar (float, int, string,
          etc.) or, if the response tensor has a non-empty `shape` attribute, a
          NumPy array of the payload with the correct type and shape. See
          `tensor_utils.unpack_tensor` for more details.

        Raises:
          TypeError: if the connection's reply does not hold a
            `properties_pb2.PropertyResponse`.
        """
        packed_request = any_pb2.Any()
        packed_request.Pack(
            properties_pb2.PropertyRequest(
                read_property=properties_pb2.ReadPropertyRequest(key=key)))
        packed_response = self._connection.send(packed_request)

        response = properties_pb2.PropertyResponse()
        # `Any.Unpack` returns False and leaves `response` empty when the
        # payload is a different message type. Fail here with a clear message
        # instead of later unpacking an empty tensor.
        if not packed_response.Unpack(response):
            raise TypeError(
                'Expected the property response to contain a PropertyResponse, '
                f'got type URL "{packed_response.type_url}".')
        return tensor_utils.unpack_tensor(response.read_property.value)
コード例 #13
0
ファイル: docker.py プロジェクト: frangipane/dm_construction
 def read_property(self, name):
     """Reads one property of the Unity environment and unpacks its value.

     Args:
       name: Key of the property to read.

     Returns:
       The property's tensor, converted by `tensor_utils.unpack_tensor`.
     """
     request = dm_env_rpc_pb2.ReadPropertyRequest(keys=[name])
     response = self._connection.send(request)
     return tensor_utils.unpack_tensor(response.properties[name])
コード例 #14
0
 def test_scalar_with_too_many_elements_raises_error(self):
     """A shapeless tensor holding three elements is rejected."""
     proto = dm_env_rpc_pb2.Tensor()
     proto.int32s.array.extend([1, 2, 3])
     with self.assertRaisesRegex(ValueError, '3 element'):
         tensor_utils.unpack_tensor(proto)
コード例 #15
0
 def test_unknown_type_raises_error(self):
     """A payload type name the unpacker does not know raises TypeError."""
     fake_tensor = mock.MagicMock(**{'WhichOneof.return_value': 'foo'})
     with self.assertRaisesRegex(TypeError, 'type foo'):
         tensor_utils.unpack_tensor(fake_tensor)
コード例 #16
0
 def test_unpack_arrays(self, array):
     """Packing and then unpacking an array yields an equal array."""
     np.testing.assert_array_equal(
         array, tensor_utils.unpack_tensor(tensor_utils.pack_tensor(array)))
コード例 #17
0
 def test_too_many_elements(self):
     """Ten elements cannot be reshaped into a 2x4 tensor."""
     proto = dm_env_rpc_pb2.Tensor()
     proto.floats.array[:] = list(range(1, 11))
     proto.shape[:] = [2, 4]
     with self.assertRaisesRegex(ValueError, 'cannot reshape array'):
         tensor_utils.unpack_tensor(proto)
コード例 #18
0
    def Process(self, request_iterator, context):
        """Processes incoming EnvironmentRequests.

        For each EnvironmentRequest the internal message is extracted and
        handled. The response for that message is then placed in an
        EnvironmentResponse, which is yielded back to the client.

        An error status is set on the EnvironmentResponse if the request has
        no payload, if an unknown message type is received, or if the message
        is invalid for the current world state.

        Args:
          request_iterator: Message iterator provided by gRPC.
          context: Context provided by gRPC (not used here).

        Yields:
          EnvironmentResponse: Response for each incoming EnvironmentRequest.
        """
        env_factory = CatchGameFactory(_INITIAL_SEED)
        env = None
        is_joined = False
        skip_next_frame = False
        action_manager = spec_manager.SpecManager(_action_spec())
        observation_manager = spec_manager.SpecManager(_observation_spec())

        def _reseed_from(settings):
            # Re-seeds the game factory when `settings` carries a 'seed' entry.
            seed = settings.get('seed', None)
            if seed is not None:
                env_factory.reset_seed(tensor_utils.unpack_tensor(seed))

        def _copy_specs_into(specs):
            # Copies every action and observation spec into a specs proto.
            for uid, action in _action_spec().items():
                specs.actions[uid].CopyFrom(action)
            for uid, observation in _observation_spec().items():
                specs.observations[uid].CopyFrom(observation)

        for request in request_iterator:
            environment_response = dm_env_rpc_pb2.EnvironmentResponse()
            try:
                message_type = request.WhichOneof('payload')
                if message_type is None:
                    # Without this guard, getattr() below fails with the opaque
                    # message "attribute name must be string".
                    raise RuntimeError('Received a request with no payload set.')
                internal_request = getattr(request, message_type)
                _check_message_type(env, is_joined, message_type)

                if message_type == 'create_world':
                    _validate_settings(
                        request.create_world.settings,
                        valid_settings=_VALID_CREATE_AND_RESET_SETTINGS)
                    _reseed_from(request.create_world.settings)
                    env = env_factory.new_game()
                    skip_next_frame = True
                    response = dm_env_rpc_pb2.CreateWorldResponse(
                        world_name=_WORLD_NAME)
                elif message_type == 'join_world':
                    _validate_settings(request.join_world.settings,
                                       valid_settings=[])
                    if is_joined:
                        raise RuntimeError(
                            f'Tried to join world "{internal_request.world_name}" but '
                            f'already joined to world "{_WORLD_NAME}"')
                    if internal_request.world_name != _WORLD_NAME:
                        raise RuntimeError(
                            f'Tried to join world "{internal_request.world_name}" but the '
                            f'only supported world is "{_WORLD_NAME}"')
                    response = dm_env_rpc_pb2.JoinWorldResponse()
                    _copy_specs_into(response.specs)
                    is_joined = True
                elif message_type == 'step':
                    # We need to skip all actions after creating or resetting
                    # the environment.
                    if skip_next_frame:
                        skip_next_frame = False
                    else:
                        unpacked_actions = action_manager.unpack(
                            internal_request.actions)
                        paddle_action = unpacked_actions.get(
                            _ACTION_PADDLE, _DEFAULT_ACTION)
                        if paddle_action not in _VALID_ACTIONS:
                            raise RuntimeError(
                                f'Invalid paddle action value: "{paddle_action}"!'
                            )
                        env.update(paddle_action)

                    response = dm_env_rpc_pb2.StepResponse()
                    packed_observations = observation_manager.pack({
                        _OBSERVATION_BOARD: env.draw_board(),
                        _OBSERVATION_REWARD: env.reward(),
                    })

                    for requested_observation in (
                            internal_request.requested_observations):
                        if requested_observation not in packed_observations:
                            # A bare KeyError would reach the client as just
                            # the UID (str(KeyError(42)) == '42').
                            raise RuntimeError(
                                'Requested unknown observation UID: '
                                f'{requested_observation}')
                        response.observations[requested_observation].CopyFrom(
                            packed_observations[requested_observation])

                    if env.has_terminated():
                        response.state = (
                            dm_env_rpc_pb2.EnvironmentStateType.TERMINATED)
                        # Start a new game right away. As after create/reset,
                        # the actions of its first step are skipped.
                        env = env_factory.new_game()
                        skip_next_frame = True
                    else:
                        response.state = (
                            dm_env_rpc_pb2.EnvironmentStateType.RUNNING)
                elif message_type == 'reset':
                    _validate_settings(
                        request.reset.settings,
                        valid_settings=_VALID_CREATE_AND_RESET_SETTINGS)
                    _reseed_from(request.reset.settings)
                    env = env_factory.new_game()
                    skip_next_frame = True
                    response = dm_env_rpc_pb2.ResetResponse()
                    _copy_specs_into(response.specs)
                elif message_type == 'reset_world':
                    _validate_settings(
                        request.reset_world.settings,
                        valid_settings=_VALID_CREATE_AND_RESET_SETTINGS)
                    _reseed_from(request.reset_world.settings)
                    env = env_factory.new_game()
                    skip_next_frame = True
                    response = dm_env_rpc_pb2.ResetWorldResponse()
                elif message_type == 'leave_world':
                    is_joined = False
                    response = dm_env_rpc_pb2.LeaveWorldResponse()
                elif message_type == 'destroy_world':
                    if internal_request.world_name != _WORLD_NAME:
                        raise RuntimeError(
                            'Tried to destroy world "{}" but we only support world "{}"'
                            .format(internal_request.world_name, _WORLD_NAME))
                    env = None
                    response = dm_env_rpc_pb2.DestroyWorldResponse()
                else:
                    raise RuntimeError(
                        'Unhandled message: {}'.format(message_type))
                getattr(environment_response, message_type).CopyFrom(response)
            except Exception as e:  # pylint: disable=broad-except
                # Service boundary: any failure is reported to the client as an
                # error status instead of terminating the stream.
                environment_response.error.CopyFrom(
                    status_pb2.Status(message=str(e)))

            yield environment_response
コード例 #19
0
 def test_two_negative_dimensions_in_matrix(self):
     """A shape with two negative dimensions is rejected."""
     proto = dm_env_rpc_pb2.Tensor()
     proto.int32s.array[:] = [1, 2, 3, 4, 5, 6]
     proto.shape[:] = [-1, -2]
     with self.assertRaisesRegex(ValueError, 'one unknown dimension'):
         tensor_utils.unpack_tensor(proto)
コード例 #20
0
 def test_unpack_scalars(self, scalar):
     """Packing and then unpacking a scalar yields an equal value."""
     packed = tensor_utils.pack_tensor(scalar)
     self.assertEqual(scalar, tensor_utils.unpack_tensor(packed))