Ejemplo n.º 1
0
  def __init__(self, connection, specs, requested_observations=None):
    """Builds an adaptor around an already-joined dm_env_rpc connection.

    Args:
      connection: A Connection that is already connected to a dm_env_rpc
        server and on which a JoinWorldRequest has already succeeded.
      specs: The dm_env_rpc ActionObservationSpecs message describing the
        environment's actions and observations.
      requested_observations: Names of the observations to ask for when step
        is called. When None, every observation named in `specs` is requested.

    Raises:
      ValueError: If any requested observation name is not present in the
        observation specs.
    """
    self._dm_env_rpc_specs = specs
    self._action_specs = spec_manager.SpecManager(specs.actions)
    self._observation_specs = spec_manager.SpecManager(specs.observations)
    self._connection = connection
    self._last_state = dm_env_rpc_pb2.EnvironmentStateType.TERMINATED

    wanted = set(
        self._observation_specs.names()
        if requested_observations is None else requested_observations)

    # Capture these flags before the default keys are force-added below, so
    # they reflect the caller's request rather than what is actually fetched.
    self._is_reward_requested = DEFAULT_REWARD_KEY in wanted
    self._is_discount_requested = DEFAULT_DISCOUNT_KEY in wanted

    self._default_reward_spec = None
    self._default_discount_spec = None

    # Reward and discount are always fetched when the environment exposes
    # them, whether or not the caller listed them.
    if DEFAULT_REWARD_KEY in self._observation_specs.names():
      reward_tensor_spec = self._observation_specs.name_to_spec(
          DEFAULT_REWARD_KEY)
      self._default_reward_spec = dm_env_utils.tensor_spec_to_dm_env_spec(
          reward_tensor_spec)
      wanted.add(DEFAULT_REWARD_KEY)
    if DEFAULT_DISCOUNT_KEY in self._observation_specs.names():
      discount_tensor_spec = self._observation_specs.name_to_spec(
          DEFAULT_DISCOUNT_KEY)
      self._default_discount_spec = dm_env_utils.tensor_spec_to_dm_env_spec(
          discount_tensor_spec)
      wanted.add(DEFAULT_DISCOUNT_KEY)

    unsupported_observations = wanted.difference(
        self._observation_specs.names())
    if unsupported_observations:
      raise ValueError('Unsupported observations requested: {}'.format(
          unsupported_observations))

    # Sorting is not strictly required, but it keeps the UID order (and hence
    # the unit tests) deterministic despite iterating over a set.
    self._requested_observation_uids = sorted(
        self._observation_specs.name_to_uid(name) for name in wanted)
Ejemplo n.º 2
0
 def test_invalid_variable_spec_shape(self):
     """A spec whose shape has two -1 dimensions must be rejected."""
     with self.assertRaisesRegex(ValueError,
                                 'shape has > 1 variable length'):
         two_variable_dims = dm_env_rpc_pb2.TensorSpec(
             name='bar',
             shape=[1, -1, -1],
             dtype=dm_env_rpc_pb2.DataType.INT32)
         spec_manager.SpecManager({1: two_variable_dims})
Ejemplo n.º 3
0
 def test_empty_variable_shape(self):
     """Packing a one-element array against an empty-shape spec raises."""
     empty_shape_spec = dm_env_rpc_pb2.TensorSpec(
         name='bar', shape=[], dtype=dm_env_rpc_pb2.DataType.INT32)
     manager = spec_manager.SpecManager({1: empty_shape_spec})

     with self.assertRaisesRegex(ValueError, 'shape'):
         one_element = np.ones((1), dtype=np.int32)
         manager.pack({'bar': one_element})
Ejemplo n.º 4
0
 def setUp(self):
     """Creates a SpecManager with a single INT32 spec 'foo' of shape [1, -1]."""
     super(SpecManagerVariableSpecShapeTests, self).setUp()
     foo_spec = dm_env_rpc_pb2.TensorSpec(
         name='foo', shape=[1, -1], dtype=dm_env_rpc_pb2.DataType.INT32)
     self._spec_manager = spec_manager.SpecManager({101: foo_spec})
Ejemplo n.º 5
0
 def setUp(self):
     """Creates a SpecManager with a FLOAT spec 'fuzz' and an INT32 spec 'foo'."""
     super(SpecManagerTests, self).setUp()
     fuzz_spec = dm_env_rpc_pb2.TensorSpec(
         name='fuzz', shape=[2], dtype=dm_env_rpc_pb2.DataType.FLOAT)
     foo_spec = dm_env_rpc_pb2.TensorSpec(
         name='foo', shape=[3], dtype=dm_env_rpc_pb2.DataType.INT32)
     self._spec_manager = spec_manager.SpecManager({
         54: fuzz_spec,
         55: foo_spec,
     })
Ejemplo n.º 6
0
 def test_duplicate_names_raise_error(self):
     """Two UIDs whose specs share the name 'fuzz' must be rejected."""
     first_fuzz = dm_env_rpc_pb2.TensorSpec(
         name='fuzz', shape=[3], dtype=dm_env_rpc_pb2.DataType.FLOAT)
     second_fuzz = dm_env_rpc_pb2.TensorSpec(
         name='fuzz', shape=[2], dtype=dm_env_rpc_pb2.DataType.FLOAT)
     clashing_specs = {54: first_fuzz, 55: second_fuzz}

     with self.assertRaisesRegex(ValueError, 'duplicate name'):
         spec_manager.SpecManager(clashing_specs)
Ejemplo n.º 7
0
    def test_spec(self):
        """dm_env_spec maps each spec name to an Array of matching shape/dtype."""
        fuzz_tensor_spec = dm_env_rpc_pb2.TensorSpec(
            name='fuzz', shape=[3], dtype=dm_env_rpc_pb2.DataType.FLOAT)
        foo_tensor_spec = dm_env_rpc_pb2.TensorSpec(
            name='foo', shape=[2], dtype=dm_env_rpc_pb2.DataType.INT32)
        manager = spec_manager.SpecManager({
            54: fuzz_tensor_spec,
            55: foo_tensor_spec,
        })

        # The result is keyed by spec name rather than by UID.
        expected = {
            'foo': specs.Array(shape=[2], dtype=np.int32),
            'fuzz': specs.Array(shape=[3], dtype=np.float32)
        }
        actual = dm_env_utils.dm_env_spec(manager)

        self.assertDictEqual(expected, actual)
Ejemplo n.º 8
0
    def Process(self, request_iterator, context):
        """Processes incoming EnvironmentRequests.

        For each EnvironmentRequest the internal message is extracted and
        handled. The response for that message is then placed in an
        EnvironmentResponse which is returned to the client.

        An error status will be returned if an unknown message type is received
        or if the message is invalid for the current world state. Any exception
        raised while handling a request is reported to the client through the
        response's `error` field; the request loop then continues with the next
        request.

        Args:
          request_iterator: Message iterator provided by gRPC.
          context: Context provided by gRPC. Not used by this method.

        Yields:
          EnvironmentResponse: Response for each incoming EnvironmentRequest.
        """

        # Per-call state: each invocation of Process gets its own game factory,
        # world (`env`), join flag and spec managers.
        env_factory = CatchGameFactory(_INITIAL_SEED)
        env = None
        is_joined = False
        # True right after a world is created or reset: the next step request
        # returns observations without applying its action.
        skip_next_frame = False
        action_manager = spec_manager.SpecManager(_action_spec())
        observation_manager = spec_manager.SpecManager(_observation_spec())

        for request in request_iterator:
            environment_response = dm_env_rpc_pb2.EnvironmentResponse()
            try:
                message_type = request.WhichOneof('payload')
                internal_request = getattr(request, message_type)
                # NOTE(review): presumably raises when the message is not valid
                # for the current env/joined state -- confirm against
                # _check_message_type.
                _check_message_type(env, is_joined, message_type)

                if message_type == 'create_world':
                    env = env_factory.new_game()
                    skip_next_frame = True
                    response = dm_env_rpc_pb2.CreateWorldResponse(
                        world_name=_WORLD_NAME)
                elif message_type == 'join_world':
                    if is_joined:
                        raise RuntimeError(
                            f'Tried to join world "{internal_request.world_name}" but '
                            f'already joined to world "{_WORLD_NAME}"')
                    if internal_request.world_name != _WORLD_NAME:
                        raise RuntimeError(
                            f'Tried to join world "{internal_request.world_name}" but the '
                            f'only supported world is "{_WORLD_NAME}"')
                    # The join response carries the full action and observation
                    # specs.
                    response = dm_env_rpc_pb2.JoinWorldResponse()
                    for uid, action in _action_spec().items():
                        response.specs.actions[uid].CopyFrom(action)
                    for uid, observation in _observation_spec().items():
                        response.specs.observations[uid].CopyFrom(observation)
                    is_joined = True
                elif message_type == 'step':
                    # We need to skip all actions after creating or resetting the
                    # environment.
                    if skip_next_frame:
                        skip_next_frame = False
                    else:
                        unpacked_actions = action_manager.unpack(
                            internal_request.actions)
                        # Fall back to the default action when the request does
                        # not contain a paddle action.
                        paddle_action = unpacked_actions.get(
                            _ACTION_PADDLE, _DEFAULT_ACTION)
                        env.update(paddle_action)

                    response = dm_env_rpc_pb2.StepResponse()
                    # Board and reward are always packed; only the observations
                    # the client asked for are copied into the response below.
                    packed_observations = observation_manager.pack({
                        _OBSERVATION_BOARD:
                        env.draw_board(),
                        _OBSERVATION_REWARD:
                        env.reward()
                    })

                    for requested_observation in internal_request.requested_observations:
                        response.observations[requested_observation].CopyFrom(
                            packed_observations[requested_observation])
                    if env.has_terminated():
                        response.state = dm_env_rpc_pb2.EnvironmentStateType.TERMINATED
                    else:
                        response.state = dm_env_rpc_pb2.EnvironmentStateType.RUNNING

                    # A finished game is replaced immediately, so the next step
                    # starts a fresh game and its action is skipped.
                    if env.has_terminated():
                        env = env_factory.new_game()
                        skip_next_frame = True
                elif message_type == 'reset':
                    env = env_factory.new_game()
                    skip_next_frame = True
                    # Like join_world, the reset response carries the full specs.
                    response = dm_env_rpc_pb2.ResetResponse()
                    for uid, action in _action_spec().items():
                        response.specs.actions[uid].CopyFrom(action)
                    for uid, observation in _observation_spec().items():
                        response.specs.observations[uid].CopyFrom(observation)
                elif message_type == 'reset_world':
                    env = env_factory.new_game()
                    skip_next_frame = True
                    response = dm_env_rpc_pb2.ResetWorldResponse()
                elif message_type == 'leave_world':
                    is_joined = False
                    response = dm_env_rpc_pb2.LeaveWorldResponse()
                elif message_type == 'destroy_world':
                    if internal_request.world_name != _WORLD_NAME:
                        raise RuntimeError(
                            'Tried to destroy world "{}" but we only support world "{}"'
                            .format(internal_request.world_name, _WORLD_NAME))
                    env = None
                    response = dm_env_rpc_pb2.DestroyWorldResponse()
                else:
                    raise RuntimeError(
                        'Unhandled message: {}'.format(message_type))
                # The response field has the same name as the request's payload
                # field, so message_type selects where the response is stored.
                getattr(environment_response, message_type).CopyFrom(response)
            # Deliberately broad: every failure is turned into an error status
            # for the client instead of propagating out of the generator.
            except Exception as e:  # pylint: disable=broad-except
                environment_response.error.CopyFrom(
                    status_pb2.Status(message=str(e)))

            yield environment_response
Ejemplo n.º 9
0
 def test_empty_spec(self):
     """A SpecManager built from no specs converts to an empty dict."""
     empty_manager = spec_manager.SpecManager({})
     converted = dm_env_utils.dm_env_spec(empty_manager)
     self.assertDictEqual({}, converted)
Ejemplo n.º 10
0
    def __init__(
        self,
        connection: dm_env_rpc_connection.Connection,
        specs: dm_env_rpc_pb2.ActionObservationSpecs,
        requested_observations: Optional[Sequence[str]] = None,
        nested_tensors: bool = True,
        extensions: Optional[Mapping[str,
                                     Any]] = immutabledict.immutabledict()):
        """Builds an adaptor around an already-joined dm_env_rpc connection.

        Args:
          connection: A Connection that is already connected to a dm_env_rpc
            server and on which a JoinWorldRequest has already succeeded.
          specs: The dm_env_rpc ActionObservationSpecs message describing the
            environment's actions and observations.
          requested_observations: Names of the observations to ask for when
            step is called. When None, every observation named in `specs` is
            requested.
          nested_tensors: Boolean to determine whether to flatten/unflatten
            tensors.
          extensions: Optional mapping from attribute name to extension
            instance; each entry is attached to this object under that name.

        Raises:
          ValueError: If a requested observation name is not present in the
            observation specs, or if an extension name collides with an
            attribute this object already has.
        """
        self._dm_env_rpc_specs = specs
        self._action_specs = spec_manager.SpecManager(specs.actions)
        self._observation_specs = spec_manager.SpecManager(specs.observations)
        self._connection = connection
        self._last_state = dm_env_rpc_pb2.EnvironmentStateType.TERMINATED
        self._nested_tensors = nested_tensors

        # Asking for "everything" (None) does not count as explicitly
        # requesting the reward or discount observations.
        request_all = requested_observations is None
        self._is_reward_requested = (
            not request_all and DEFAULT_REWARD_KEY in requested_observations)
        self._is_discount_requested = (
            not request_all and DEFAULT_DISCOUNT_KEY in requested_observations)

        wanted = set(self._observation_specs.names()
                     if request_all else requested_observations)

        self._default_reward_spec = None
        self._default_discount_spec = None
        # Reward and discount are always fetched when the environment exposes
        # them, whether or not the caller listed them.
        default_keys = (
            (DEFAULT_REWARD_KEY, '_default_reward_spec'),
            (DEFAULT_DISCOUNT_KEY, '_default_discount_spec'),
        )
        for key, attribute in default_keys:
            if key in self._observation_specs.names():
                tensor_spec = self._observation_specs.name_to_spec(key)
                setattr(self, attribute,
                        dm_env_utils.tensor_spec_to_dm_env_spec(tensor_spec))
                wanted.add(key)

        unsupported_observations = wanted.difference(
            self._observation_specs.names())
        if unsupported_observations:
            raise ValueError('Unsupported observations requested: {}'.format(
                unsupported_observations))

        # Sorting is not strictly required, but it keeps the UID order (and
        # hence the unit tests) deterministic despite iterating over a set.
        self._requested_observation_uids = sorted(
            self._observation_specs.name_to_uid(name) for name in wanted)

        extension_items = extensions.items() if extensions is not None else ()
        for extension_name, extension in extension_items:
            # Refuse to shadow anything already defined on this object.
            if hasattr(self, extension_name):
                raise ValueError(
                    f'DmEnvAdaptor already has attribute "{extension_name}"!'
                )
            setattr(self, extension_name, extension)
Ejemplo n.º 11
0
 def setUp(self):
     """Creates the SpecManager under test from the shared _EXAMPLE_SPECS."""
     super(SpecManagerTests, self).setUp()
     manager = spec_manager.SpecManager(_EXAMPLE_SPECS)
     self._spec_manager = manager