import pytest

from garage import StepType


def test_get_step_type():
    step_type = StepType.get_step_type(step_cnt=1,
                                       max_episode_length=5,
                                       done=False)
    assert step_type == StepType.FIRST

    step_type = StepType.get_step_type(step_cnt=2,
                                       max_episode_length=5,
                                       done=False)
    assert step_type == StepType.MID

    step_type = StepType.get_step_type(step_cnt=2,
                                       max_episode_length=None,
                                       done=False)
    assert step_type == StepType.MID

    step_type = StepType.get_step_type(step_cnt=5,
                                       max_episode_length=5,
                                       done=False)
    assert step_type == StepType.TIMEOUT

    step_type = StepType.get_step_type(step_cnt=5,
                                       max_episode_length=5,
                                       done=True)
    assert step_type == StepType.TIMEOUT

    step_type = StepType.get_step_type(step_cnt=1,
                                       max_episode_length=5,
                                       done=True)
    assert step_type == StepType.TERMINAL

    with pytest.raises(ValueError):
        step_type = StepType.get_step_type(step_cnt=0,
                                           max_episode_length=5,
                                           done=False)
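# For reference, one implementation that satisfies the assertions above.
# This is a sketch of the semantics the test encodes, not necessarily the
# exact garage implementation (the helper name is hypothetical): TIMEOUT
# takes precedence over TERMINAL, and TERMINAL over FIRST.
def _get_step_type_sketch(step_cnt, max_episode_length, done):
    if step_cnt < 1:
        raise ValueError('step_cnt must be >= 1, got {}'.format(step_cnt))
    if max_episode_length is not None and step_cnt >= max_episode_length:
        return StepType.TIMEOUT
    if done:
        return StepType.TERMINAL
    if step_cnt == 1:
        return StepType.FIRST
    return StepType.MID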
def step(self, action):
    """Call step on wrapped env.

    Args:
        action (np.ndarray): An action provided by the agent.

    Returns:
        EnvStep: The environment step resulting from the action.

    Raises:
        RuntimeError: if `step()` is called after the environment has
            been constructed and `reset()` has not been called.

    """
    if self._step_cnt is None:
        raise RuntimeError('reset() must be called before step()!')

    observation, reward, done, info = self._env.step(action)

    if self._visualize:
        self._env.render(mode='human')

    reward = float(reward) if not isinstance(reward, float) else reward
    self._step_cnt += 1

    step_type = StepType.get_step_type(
        step_cnt=self._step_cnt,
        max_episode_length=self._max_episode_length,
        done=done)

    # gym envs wrapped in a TimeLimit wrapper set the done signal to
    # True whenever the time limit expires, conflating timeouts with
    # true terminations. The block below distinguishes the two cases:
    # step_type is set to TIMEOUT for time-limit expirations, and the
    # time-limit signal is recorded in env_info under
    # 'TimeLimit.truncated' and 'GymEnv.TimeLimitTerminated'.
    if 'TimeLimit.truncated' in info or step_type == StepType.TIMEOUT:
        info['GymEnv.TimeLimitTerminated'] = True
        info['TimeLimit.truncated'] = info.get('TimeLimit.truncated',
                                               True)
        step_type = StepType.TIMEOUT
    else:
        info['TimeLimit.truncated'] = False
        info['GymEnv.TimeLimitTerminated'] = False

    if step_type in (StepType.TERMINAL, StepType.TIMEOUT):
        self._step_cnt = None

    return EnvStep(env_spec=self.spec,
                   action=action,
                   reward=reward,
                   observation=observation,
                   env_info=info,
                   step_type=step_type)
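# Minimal usage sketch for the wrapper above, assuming garage's GymEnv
# and gym's 'CartPole-v1' (the env id and constructor arguments are
# assumptions and may differ across versions):
import numpy as np

from garage.envs import GymEnv

env = GymEnv('CartPole-v1', max_episode_length=200)
first_obs, _ = env.reset()  # reset() must precede step()
es = env.step(env.action_space.sample())
# A time-limit expiration is reported as StepType.TIMEOUT and flagged
# in env_info, so callers can tell it apart from a true termination:
print(es.step_type, es.env_info['GymEnv.TimeLimitTerminated'])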
def step(self, action):
    """Call step on wrapped env.

    Args:
        action (np.ndarray): An action provided by the agent.

    Returns:
        EnvStep: The environment step resulting from the action.

    Raises:
        RuntimeError: if `step()` is called after the environment has
            been constructed and `reset()` has not been called.
        RuntimeError: if the underlying environment outputs
            inconsistent env_info keys.

    """
    if self._step_cnt is None:
        raise RuntimeError('reset() must be called before step()!')

    observation, reward, done, info = self._env.step(action)

    if self._visualize:
        self._env.render(mode='human')

    reward = float(reward) if not isinstance(reward, float) else reward
    self._step_cnt += 1

    step_type = StepType.get_step_type(
        step_cnt=self._step_cnt,
        max_episode_length=self._max_episode_length,
        done=done)

    # gym envs wrapped in a TimeLimit wrapper set the done signal to
    # True whenever the time limit expires, conflating timeouts with
    # true terminations. The block below distinguishes the two cases:
    # step_type is set to TIMEOUT for time-limit expirations, and the
    # time-limit signal is recorded in env_info under
    # 'TimeLimit.truncated' and 'GymEnv.TimeLimitTerminated'.
    if 'TimeLimit.truncated' in info or step_type == StepType.TIMEOUT:
        info['GymEnv.TimeLimitTerminated'] = True
        info['TimeLimit.truncated'] = info.get('TimeLimit.truncated',
                                               True)
        step_type = StepType.TIMEOUT
    else:
        info['TimeLimit.truncated'] = False
        info['GymEnv.TimeLimitTerminated'] = False

    if step_type in (StepType.TERMINAL, StepType.TIMEOUT):
        self._step_cnt = None

    # Check that env_info keys are consistent across steps.
    if not self._env_info:
        self._env_info = {k: type(info[k]) for k in info}
    elif self._env_info.keys() != info.keys():
        raise RuntimeError('GymEnv outputs inconsistent env_info keys.')

    if not self.spec.observation_space.contains(observation):
        # Discrete observations can be either in the space normally,
        # or one-hot encoded.
        if self.spec.observation_space.flat_dim != np.prod(
                observation.shape):
            raise RuntimeError('GymEnv observation shape does not '
                               'conform to its observation_space')

    return EnvStep(env_spec=self.spec,
                   action=action,
                   reward=reward,
                   observation=observation,
                   env_info=info,
                   step_type=step_type)
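# Why the one-hot case passes the shape check above: for a Discrete(n)
# space, flat_dim equals n, so a one-hot vector of shape (n,) satisfies
# flat_dim == np.prod(observation.shape) even though it is not a raw
# sample of the space. A minimal sketch using akro (the space library
# garage builds on); the example values are made up:
import numpy as np
import akro

space = akro.Discrete(4)
one_hot = np.eye(4)[2]  # one-hot encoding of state 2
assert not space.contains(one_hot)  # not a raw Discrete sample...
assert space.flat_dim == np.prod(one_hot.shape)  # ...but flat dims agree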
def step(self, action):
    """Step the environment.

    Args:
        action (np.ndarray): An action provided by the agent.

    Returns:
        EnvStep: The environment step resulting from the action.

    Raises:
        RuntimeError: if `step()` is called after the environment has
            been constructed and `reset()` has not been called.

    """
    if self._step_cnt is None:
        raise RuntimeError('reset() must be called before step()!')

    # Enforce the action space. NOTE: the action MUST be copied before
    # it is modified, so the caller's array is left untouched.
    a = action.copy()
    a = np.clip(a, self.action_space.low, self.action_space.high)

    self._point = np.clip(self._point + a, -self._arena_size,
                          self._arena_size)
    if self._visualize:
        print(self.render('ascii'))

    dist = np.linalg.norm(self._point - self._goal)
    # Success means the point is within roughly one maximum-magnitude
    # action of the goal.
    succ = dist < np.linalg.norm(self.action_space.low)

    # Dense reward: negative distance to the goal...
    reward = -dist
    # ...plus a bonus on success.
    if succ:
        reward += self._done_bonus
    # Type conversion
    if not isinstance(reward, float):
        reward = float(reward)

    # Sometimes we don't want to terminate even on success.
    done = succ and not self._never_done

    obs = np.concatenate([self._point, (dist, )])

    self._step_cnt += 1

    step_type = StepType.get_step_type(
        step_cnt=self._step_cnt,
        max_episode_length=self._max_episode_length,
        done=done)

    if step_type in (StepType.TERMINAL, StepType.TIMEOUT):
        self._step_cnt = None

    return EnvStep(env_spec=self.spec,
                   action=action,
                   reward=reward,
                   observation=obs,
                   env_info={
                       'task': self._task,
                       'success': succ
                   },
                   step_type=step_type)
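# Minimal usage sketch for this point-mass environment, assuming it is
# garage.envs.PointEnv (the import path and constructor defaults are
# assumptions):
import numpy as np

from garage.envs import PointEnv

env = PointEnv(max_episode_length=100)
first_obs, _ = env.reset()
es = env.step(np.array([0.1, 0.1]))
# Observation is (x, y, distance-to-goal); reward is -distance.
print(es.observation, es.reward, es.env_info['success'])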
def step(self, action):
    """Step the environment.

    action map:
    0: left
    1: down
    2: right
    3: up

    Args:
        action (int): an int encoding the action

    Returns:
        EnvStep: The environment step resulting from the action.

    Raises:
        RuntimeError: if `step()` is called after the environment has
            been constructed and `reset()` has not been called.
        NotImplementedError: if a next step in self._desc does not
            match a known state type.

    """
    if self._step_cnt is None:
        raise RuntimeError('reset() must be called before step()!')

    possible_next_states = self._get_possible_next_states(
        self._state, action)

    probs = [x[1] for x in possible_next_states]
    next_state_idx = np.random.choice(len(probs), p=probs)
    next_state = possible_next_states[next_state_idx][0]

    next_x = next_state // self._n_col
    next_y = next_state % self._n_col

    next_state_type = self._desc[next_x, next_y]
    if next_state_type == 'H':
        # Fell into a hole: episode ends with no reward.
        done = True
        reward = 0.0
    elif next_state_type in ['F', 'S']:
        # Frozen or start tile: keep going.
        done = False
        reward = 0.0
    elif next_state_type == 'G':
        # Reached the goal.
        done = True
        reward = 1.0
    else:
        raise NotImplementedError

    self._state = next_state

    self._step_cnt += 1
    step_type = StepType.get_step_type(
        step_cnt=self._step_cnt,
        max_episode_length=self._max_episode_length,
        done=done)

    if step_type in (StepType.TERMINAL, StepType.TIMEOUT):
        self._step_cnt = None

    return EnvStep(env_spec=self.spec,
                   action=action,
                   reward=reward,
                   observation=next_state,
                   env_info={},
                   step_type=step_type)
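# _get_possible_next_states is not shown here; the sampling above only
# assumes it returns a list of (next_state, probability) pairs whose
# probabilities sum to 1. A self-contained sketch of that contract, with
# made-up transition values:
import numpy as np

possible_next_states = [(5, 0.8), (1, 0.1), (9, 0.1)]  # hypothetical
probs = [p for _, p in possible_next_states]
idx = np.random.choice(len(probs), p=probs)
next_state = possible_next_states[idx][0]
# Decode the flat state index into (row, col) for an n_col-wide grid:
n_col = 4
row, col = next_state // n_col, next_state % n_col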