def observe(
    self,
    action: Dict[str, types.NestedArray],
    next_timestep: dm_env.TimeStep,
    next_extras: Dict[str, types.NestedArray] = None,  # effectively Optional
) -> None:
    """Validate a multi-agent transition against each agent's spec.

    For every agent in the spec, checks (when present) the agent's action,
    reward, discount and next observation against the corresponding
    sub-spec via ``_validate_spec``.

    Args:
        action: Mapping from agent id to the action it took.
        next_timestep: The ``dm_env.TimeStep`` produced by that action;
            its reward/discount/observation fields are per-agent mappings.
        next_extras: Optional extras accompanying the timestep. Defaults
            to ``None`` (treated as empty) — was a mutable ``{}`` default,
            which is a Python anti-pattern even when only read.
    """
    if next_extras is None:
        next_extras = {}
    for agent, observation_spec in self._spec.items():
        if agent in action:
            _validate_spec(observation_spec.actions, action[agent])
        if agent in next_timestep.reward:
            _validate_spec(observation_spec.rewards, next_timestep.reward[agent])
        if agent in next_timestep.discount:
            _validate_spec(
                observation_spec.discounts, next_timestep.discount[agent]
            )
        if next_timestep.observation and agent in next_timestep.observation:
            _validate_spec(
                observation_spec.observations, next_timestep.observation[agent]
            )
    if next_extras:
        # NOTE(review): single-argument call, unlike the (spec, value)
        # calls above — presumably a separate overload/signature of
        # _validate_spec; confirm before relying on it.
        _validate_spec(next_extras)
def step(
    self, actions: Dict[str, Union[float, int, types.NestedArray]]
) -> dm_env.TimeStep:
    """Take one fake parallel step with a per-agent action mapping.

    Validates every action against its agent's spec, fabricates per-agent
    observations/rewards/discounts, and returns either a LAST timestep
    (when the configured episode length is reached) or a mid-episode
    transition.
    """
    # Environment has not been touched yet: hand back a reset timestep.
    if not self._step:
        return self.reset()

    for agent_id, agent_action in actions.items():
        _validate_spec(self._specs[agent_id].actions, agent_action)

    observation = {a: self._generate_fake_observation() for a in self.agents}
    reward = {a: self._generate_fake_reward() for a in self.agents}
    discount = {a: self._generate_fake_discount() for a in self.agents}

    episode_over = bool(self._episode_length) and (
        self._step == self._episode_length
    )
    if episode_over:
        self._step = 0
        # dm_env.termination rebuilds the discount itself, so it would not
        # necessarily conform to the spec (e.g. if we want float32); build
        # the terminal TimeStep by hand instead.
        return dm_env.TimeStep(dm_env.StepType.LAST, reward, discount, observation)

    self._step += 1
    return dm_env.transition(
        reward=reward, observation=observation, discount=discount
    )
def step(self, action: Union[float, int, types.NestedArray]) -> dm_env.TimeStep:
    """Take one fake turn-based (sequential) environment step.

    Each call advances a single agent's turn; the environment step counter
    only moves once every agent in ``self.agents`` has acted.
    """
    # Environment has not been touched yet: hand back a reset timestep.
    if not self._step:
        return self.reset()

    _validate_spec(self._spec.actions, action)

    observation = self._generate_fake_observation()
    reward = self._generate_fake_reward()
    discount = self._generate_fake_discount()

    self.agent_step_counter += 1
    all_agents_stepped = self.agent_step_counter == len(self.agents)

    if self._episode_length and (self._step == self._episode_length):
        # Reset the step counters only once every agent has taken its turn;
        # every agent still receives the terminal timestep this round.
        if all_agents_stepped:
            self._step = 0
            self.agent_step_counter = 0
        # dm_env.termination rebuilds the discount itself, so it would not
        # necessarily conform to the spec (e.g. if we want float32); build
        # the terminal TimeStep by hand instead.
        return dm_env.TimeStep(dm_env.StepType.LAST, reward, discount, observation)

    # Advance the environment step only once all agents have taken a turn.
    if all_agents_stepped:
        self._step += 1
        self.agent_step_counter = 0
    return dm_env.transition(
        reward=reward, observation=observation, discount=discount
    )
def observe_first(
    self,
    timestep: dm_env.TimeStep,
    extras: Dict[str, types.NestedArray] = None,  # effectively Optional
) -> None:
    """Validate every agent's first observation against its spec.

    Args:
        timestep: The initial ``dm_env.TimeStep``; ``timestep.observation``
            must contain an entry for every agent in ``self._specs``.
        extras: Optional extras accompanying the first timestep. Defaults
            to ``None`` (treated as empty) — was a mutable ``{}`` default,
            which is a Python anti-pattern even when only read.
    """
    if extras is None:
        extras = {}
    for agent, observation_spec in self._specs.items():
        _validate_spec(
            observation_spec.observations,
            timestep.observation[agent],
        )
    if extras:
        # NOTE(review): single-argument call, unlike the (spec, value)
        # call above — presumably a separate overload/signature of
        # _validate_spec; confirm before relying on it.
        _validate_spec(extras)
def agent_observe(
    self,
    agent: str,
    action: Union[float, int, types.NestedArray],
    next_timestep: dm_env.TimeStep,
) -> None:
    """Validate one agent's action, reward and discount against its spec."""
    agent_spec = self._spec[agent]
    checks = (
        (agent_spec.actions, action),
        (agent_spec.rewards, next_timestep.reward),
        (agent_spec.discounts, next_timestep.discount),
    )
    for sub_spec, value in checks:
        _validate_spec(sub_spec, value)
def agent_observe_first(self, agent: str, timestep: dm_env.TimeStep) -> None:
    """Validate one agent's first observation against its observation spec."""
    agent_spec = self._spec[agent]
    _validate_spec(agent_spec.observations, timestep.observation)