Ejemplo n.º 1
0
    def observe(
        self,
        action: Dict[str, types.NestedArray],
        next_timestep: dm_env.TimeStep,
        # The shared `{}` default is only read in this method, never mutated.
        next_extras: Dict[str, types.NestedArray] = {},
    ) -> None:
        """Check a joint action and its accompanying timestep against the specs.

        For each agent in ``self._spec``, whichever of its action, reward,
        discount and observation entries are present get passed to
        ``_validate_spec`` together with the matching part of that agent's
        spec. An agent that is missing from one of the mappings is simply
        skipped for that field.

        Args:
            action: actions keyed by agent name.
            next_timestep: timestep whose ``reward``, ``discount`` and
                ``observation`` fields are looked up by agent name.
            next_extras: extras; only checked when non-empty.
        """
        for agent_id, agent_spec in self._spec.items():
            if agent_id in action:
                _validate_spec(agent_spec.actions, action[agent_id])

            rewards = next_timestep.reward
            if agent_id in rewards:
                _validate_spec(agent_spec.rewards, rewards[agent_id])

            discounts = next_timestep.discount
            if agent_id in discounts:
                _validate_spec(agent_spec.discounts, discounts[agent_id])

            # The observation may be empty/None, so test truthiness before
            # looking the agent up in it.
            observations = next_timestep.observation
            if observations and agent_id in observations:
                _validate_spec(agent_spec.observations,
                               observations[agent_id])

        if not next_extras:
            return
        # NOTE(review): every other call in this method passes (spec, value);
        # confirm that `_validate_spec` accepts a single argument here.
        _validate_spec(next_extras)
Ejemplo n.º 2
0
    def step(
        self, actions: Dict[str, Union[float, int, types.NestedArray]]
    ) -> dm_env.TimeStep:
        """Validate a joint action and return a timestep of generated values.

        Args:
            actions: actions keyed by agent name; each one is checked against
                ``self._specs[agent].actions``.

        Returns:
            The result of ``self.reset()`` while ``self._step`` is falsy;
            a ``LAST`` timestep (resetting ``self._step`` to 0) once
            ``self._step`` equals a non-zero ``self._episode_length``;
            otherwise a transition timestep, after incrementing
            ``self._step``.
        """
        # Return a reset timestep if we haven't touched the environment yet.
        if not self._step:
            return self.reset()

        for agent_id, agent_action in actions.items():
            _validate_spec(self._specs[agent_id].actions, agent_action)

        # Build each field in its own pass over the agents, in the order
        # observation -> reward -> discount.
        observation = {}
        for agent_id in self.agents:
            observation[agent_id] = self._generate_fake_observation()

        reward = {}
        for agent_id in self.agents:
            reward[agent_id] = self._generate_fake_reward()

        discount = {}
        for agent_id in self.agents:
            discount[agent_id] = self._generate_fake_discount()

        if not (self._episode_length
                and self._step == self._episode_length):
            self._step += 1
            return dm_env.transition(reward=reward,
                                     observation=observation,
                                     discount=discount)

        self._step = 0
        # We can't use dm_env.termination directly because then the discount
        # wouldn't necessarily conform to the spec (if eg. we want float32).
        return dm_env.TimeStep(dm_env.StepType.LAST, reward, discount,
                               observation)
Ejemplo n.º 3
0
    def step(self, action: Union[float, int,
                                 types.NestedArray]) -> dm_env.TimeStep:
        """Validate one agent's action and return a timestep of generated values.

        Each call bumps ``self.agent_step_counter``. The shared ``self._step``
        counter only changes on the call where that counter reaches
        ``len(self.agents)``, at which point the per-agent counter is cleared.

        Args:
            action: action checked against ``self._spec.actions``.

        Returns:
            The result of ``self.reset()`` while ``self._step`` is falsy;
            a ``LAST`` timestep when ``self._step`` equals a non-zero
            ``self._episode_length``; otherwise a transition timestep.
        """
        # Return a reset timestep if we haven't touched the environment yet.
        if not self._step:
            return self.reset()

        _validate_spec(self._spec.actions, action)

        fake_observation = self._generate_fake_observation()
        fake_reward = self._generate_fake_reward()
        fake_discount = self._generate_fake_discount()

        self.agent_step_counter += 1

        # Decide whether the episode is over *before* touching `_step` below.
        episode_over = self._episode_length and (
            self._step == self._episode_length)

        # Only reset/advance the step counter once all agents have taken
        # their turn.
        if self.agent_step_counter == len(self.agents):
            self._step = 0 if episode_over else self._step + 1
            self.agent_step_counter = 0

        if episode_over:
            # We can't use dm_env.termination directly because then the
            # discount wouldn't necessarily conform to the spec (if eg. we
            # want float32).
            return dm_env.TimeStep(dm_env.StepType.LAST, fake_reward,
                                   fake_discount, fake_observation)

        return dm_env.transition(reward=fake_reward,
                                 observation=fake_observation,
                                 discount=fake_discount)
Ejemplo n.º 4
0
 def observe_first(
     self,
     timestep: dm_env.TimeStep,
     # The shared `{}` default is only read in this method, never mutated.
     extras: Dict[str, types.NestedArray] = {},
 ) -> None:
     """Check every agent's observation in ``timestep`` against its spec.

     Args:
         timestep: timestep whose ``observation`` is indexed by agent name;
             it must hold an entry for every agent in ``self._specs``.
         extras: extras; only checked when non-empty.
     """
     for agent_id, agent_spec in self._specs.items():
         expected = agent_spec.observations
         received = timestep.observation[agent_id]
         _validate_spec(expected, received)

     if not extras:
         return
     # NOTE(review): the call in the loop above passes (spec, value);
     # confirm that `_validate_spec` accepts a single argument here.
     _validate_spec(extras)
Ejemplo n.º 5
0
 def agent_observe(
     self,
     agent: str,
     action: Union[float, int, types.NestedArray],
     next_timestep: dm_env.TimeStep,
 ) -> None:
     """Check one agent's action, reward and discount against its spec.

     The timestep's observation is not checked here.

     Args:
         agent: key into ``self._spec`` selecting the spec to check against.
         action: value checked against the spec's ``actions``.
         next_timestep: timestep supplying the ``reward`` and ``discount``
             that are checked.
     """
     agent_spec = self._spec[agent]

     # Checks run in a fixed order: action, then reward, then discount.
     _validate_spec(agent_spec.actions, action)

     _validate_spec(agent_spec.rewards, next_timestep.reward)

     _validate_spec(agent_spec.discounts, next_timestep.discount)
Ejemplo n.º 6
0
 def agent_observe_first(self, agent: str,
                         timestep: dm_env.TimeStep) -> None:
     """Check the timestep's observation against one agent's observation spec.

     Args:
         agent: key into ``self._spec`` selecting the spec to check against.
         timestep: timestep whose ``observation`` is checked as a whole.
     """
     expected = self._spec[agent].observations
     received = timestep.observation
     _validate_spec(expected, received)