Example #1
0
def test_new_env_step(sample_data):
    """Check that EnvStep keeps references to the values it is built from.

    The step is first built from the fixture unchanged; it is then rebuilt
    with an env_spec made of narrow Box spaces while the observation and
    action are sampled from Box spaces with much wider bounds.

    Args:
        sample_data (dict): Keyword arguments for the EnvStep constructor.
            The dict is modified in place by this test.

    """
    step = EnvStep(**sample_data)
    for field in ('env_spec', 'observation', 'action', 'reward',
                  'step_type', 'env_info'):
        assert getattr(step, field) is sample_data[field]
    del step

    # Spec built from spaces bounded by [-1, 10].
    narrow_obs_space = akro.Box(low=-1,
                                high=10,
                                shape=(4, 3, 2),
                                dtype=np.float32)
    narrow_act_space = akro.Box(low=-1,
                                high=10,
                                shape=(4, 2),
                                dtype=np.float32)
    sample_data['env_spec'] = EnvSpec(narrow_obs_space, narrow_act_space)

    # Samples drawn from spaces bounded by [-1000, 1000].
    wide_obs_space = akro.Box(low=-1000,
                              high=1000,
                              shape=(4, 3, 2),
                              dtype=np.float32)
    wide_act_space = akro.Box(low=-1000,
                              high=1000,
                              shape=(4, 2),
                              dtype=np.float32)
    sample_data['observation'] = wide_obs_space.sample()
    sample_data['action'] = wide_act_space.sample()
    step = EnvStep(**sample_data)

    assert step.observation is sample_data['observation']
    assert step.action is sample_data['action']
Example #2
0
    def step(self, action):
        """Step the wrapped env and tag the result with the task.

        The observation returned by the wrapped environment is passed
        through `_obs_with_one_hot`, and the task index is written into
        the step's `env_info` under the key `task_id`.

        Args:
            action (np.ndarray): Action performed by the agent in the
                environment.

        Returns:
            EnvStep: The environment step resulting from the action.

        """
        inner_step = self._env.step(action)
        one_hot_obs = self._obs_with_one_hot(inner_step.observation)

        # The wrapped step's env_info object is updated in place and reused.
        info = inner_step.env_info
        info['task_id'] = self._task_index

        return EnvStep(env_spec=self.spec,
                       action=action,
                       reward=inner_step.reward,
                       observation=one_hot_obs,
                       env_info=info,
                       step_type=inner_step.step_type)
Example #3
0
    def step(self, action):
        """Step the wrapped env and augment the observation.

        The returned observation is the wrapped environment's observation
        concatenated with the action, the reward and a flag telling whether
        the step type is `StepType.TERMINAL`.

        Args:
            action (np.ndarray): An action provided by the agent.

        Returns:
            EnvStep: The environment step resulting from the action.

        Raises:
            RuntimeError: if `step()` is called after the environment has been
                constructed and `reset()` has not been called.

        """
        inner_step = self._env.step(action)
        is_terminal = inner_step.step_type == StepType.TERMINAL
        augmented_obs = np.concatenate([
            inner_step.observation,
            action,
            [inner_step.reward],
            [is_terminal],
        ])

        return EnvStep(env_spec=self.spec,
                       action=action,
                       reward=inner_step.reward,
                       observation=augmented_obs,
                       env_info=inner_step.env_info,
                       step_type=inner_step.step_type)
Example #4
0
def test_step_type_property_env_step(sample_data):
    """Check the boolean step-type properties of EnvStep.

    For every step type an EnvStep is built and the properties expected to
    be truthy for that type are asserted in order.

    Args:
        sample_data (dict): Keyword arguments for the EnvStep constructor.
            Its `step_type` entry is overwritten by this test.

    """
    expected_flags = (
        (StepType.FIRST, ('first', )),
        (StepType.MID, ('mid', )),
        (StepType.TERMINAL, ('terminal', 'last')),
        (StepType.TIMEOUT, ('timeout', 'last')),
    )
    for step_type, flags in expected_flags:
        sample_data['step_type'] = step_type
        step = EnvStep(**sample_data)
        for flag in flags:
            assert getattr(step, flag)
Example #5
0
    def step(self, action):
        """Step the active task env.

        Depending on the wrapper mode the observation either gets the
        active task's one-hot vector appended ('add-onehot'), has its last
        `_num_tasks` entries removed ('del-onehot'), or is passed through
        untouched (any other mode, i.e. 'vanilla'). The step's `env_info`
        gets a `task_id` entry if it has none and, when environment names
        are available, a `task_name` entry.

        Args:
            action (object): object to be passed in Environment.step(action)

        Returns:
            EnvStep: The environment step resulting from the action.

        """
        inner_step = self._env.step(action)

        obs = inner_step.observation
        if self._mode == 'add-onehot':
            obs = np.concatenate([obs, self._active_task_one_hot()])
        elif self._mode == 'del-onehot':
            obs = obs[:-self._num_tasks]

        # The wrapped step's env_info object is updated in place and reused.
        info = inner_step.env_info
        if 'task_id' not in info:
            info['task_id'] = self._active_task_index
        if self._env_names is not None:
            info['task_name'] = self._env_names[self._active_task_index]

        return EnvStep(env_spec=self.spec,
                       action=action,
                       reward=inner_step.reward,
                       observation=obs,
                       env_info=info,
                       step_type=inner_step.step_type)
Example #6
0
    def step(self, action):
        """Step the wrapped gym env and package the result as an EnvStep.

        Args:
            action (np.ndarray): An action provided by the agent.

        Returns:
            EnvStep: The environment step resulting from the action.

        Raises:
            RuntimeError: if `step()` is called after the environment has been
                constructed and `reset()` has not been called.

        """
        if self._step_cnt is None:
            raise RuntimeError('reset() must be called before step()!')

        observation, reward, done, info = self._env.step(action)

        if self._visualize:
            self._env.render(mode='human')

        if not isinstance(reward, float):
            reward = float(reward)
        self._step_cnt += 1

        step_type = StepType.get_step_type(
            step_cnt=self._step_cnt,
            max_episode_length=self._max_episode_length,
            done=done)

        # gym envs wrapped in the TimeLimit wrapper report done=True when
        # the time limit expires. Such steps are relabelled as TIMEOUT so
        # that only a genuine environment termination counts as TERMINAL.
        # The time limit signal is recorded in the env info under
        # 'GymEnv.TimeLimitTerminated'.
        hit_time_limit = ('TimeLimit.truncated' in info
                          or step_type == StepType.TIMEOUT)
        if hit_time_limit:
            info['GymEnv.TimeLimitTerminated'] = True
            # Keep a value already reported by the env, default to True.
            info.setdefault('TimeLimit.truncated', True)
            step_type = StepType.TIMEOUT
        else:
            info['TimeLimit.truncated'] = False
            info['GymEnv.TimeLimitTerminated'] = False

        # Episode is over: require a reset() before the next step().
        episode_over = step_type in (StepType.TERMINAL, StepType.TIMEOUT)
        if episode_over:
            self._step_cnt = None

        return EnvStep(env_spec=self.spec,
                       action=action,
                       reward=reward,
                       observation=observation,
                       env_info=info,
                       step_type=step_type)
Example #7
0
def test_from_env_step_time_step(sample_data):
    """Check that TimeStep.from_env_step rebuilds an equal TimeStep.

    A TimeStep is built directly from the fixture. An EnvStep is then built
    from the same data with the TimeStep-only entries removed and with the
    next observation as its observation, and converted back.

    Args:
        sample_data (dict): Keyword arguments for the TimeStep constructor.
            The dict is modified in place by this test.

    """
    expected = TimeStep(**sample_data)

    # pop() both reads the value and removes the TimeStep-only key.
    previous_obs = sample_data['observation']
    agent_info = sample_data.pop('agent_info')
    sample_data['observation'] = sample_data.pop('next_observation')

    rebuilt = TimeStep.from_env_step(env_step=EnvStep(**sample_data),
                                     last_observation=previous_obs,
                                     agent_info=agent_info)
    assert expected == rebuilt
Example #8
0
    def step(self, action):
        """Steps the environment with the action and returns a `EnvStep`.

        The wrapped environment's observation is flattened for the
        `observation` field, while the unflattened observation is returned
        as `env_info`.

        Args:
            action (object): input action

        Returns:
            EnvStep: The environment step resulting from the action.

        Raises:
            RuntimeError: if `step()` is called after the environment has been
                constructed and `reset()` has not been called.
        """
        if self._step_cnt is None:
            raise RuntimeError('reset() must be called before step()!')

        dm_time_step = self._env.step(action)
        if self._viewer:
            self._viewer.render()

        flat_obs = flatten_observation(dm_time_step.observation)
        observation = flat_obs['observations']

        self._step_cnt += 1

        # Translate the wrapped env's step type; anything other than MID
        # or LAST leaves the step type as None.
        dm_type = dm_time_step.step_type
        if dm_type == dm_StepType.LAST:
            step_type = StepType.TERMINAL
        elif dm_type == dm_StepType.MID:
            out_of_time = self._step_cnt >= self._max_episode_length
            step_type = StepType.TIMEOUT if out_of_time else StepType.MID
        else:
            step_type = None

        # Episode is over: require a reset() before the next step().
        if step_type in (StepType.TERMINAL, StepType.TIMEOUT):
            self._step_cnt = None

        return EnvStep(env_spec=self.spec,
                       action=action,
                       reward=dm_time_step.reward,
                       observation=observation,
                       env_info=dm_time_step.observation,
                       step_type=step_type)
Example #9
0
    def step(self, action):
        """Call step on wrapped env.

        When the action space is an `akro.Box` whose bounds are all finite,
        the action is mapped linearly from
        `[-expected_action_scale, expected_action_scale]` onto
        `[low, high]` and clipped to those bounds. In every other case the
        action is forwarded unchanged. Observation and reward normalization
        are applied when enabled, and the reward is finally multiplied by
        the reward scale.

        Args:
            action (np.ndarray): An action provided by the agent.

        Returns:
            EnvStep: The environment step resulting from the action. Its
                `env_spec` and `action` fields are taken from the step
                returned by the wrapped environment.

        Raises:
            RuntimeError: if `step()` is called after the environment has been
                constructed and `reset()` has not been called.

        """
        scaled_action = action
        if isinstance(self.action_space, akro.Box):
            lb, ub = self.action_space.low, self.action_space.high
            # Rescaling is only meaningful when every bound is finite; an
            # infinite bound would turn the scaled action into inf/nan.
            if np.all(np.isfinite(lb)) and np.all(np.isfinite(ub)):
                scaled_action = lb + (action + self._expected_action_scale) * (
                    0.5 * (ub - lb) / self._expected_action_scale)
                scaled_action = np.clip(scaled_action, lb, ub)

        es = self._env.step(scaled_action)
        next_obs = es.observation
        reward = es.reward

        if self._normalize_obs:
            next_obs = self._apply_normalize_obs(next_obs)
        if self._normalize_reward:
            reward = self._apply_normalize_reward(reward)

        return EnvStep(env_spec=es.env_spec,
                       action=es.action,
                       reward=reward * self._scale_reward,
                       observation=next_obs,
                       env_info=es.env_info,
                       step_type=es.step_type)
Example #10
0
    def step(self, action):
        """Step the wrapped gym env and package the result as an EnvStep.

        Besides stepping, this verifies that the env info keys stay the
        same from step to step and that the observation fits the
        observation space.

        Args:
            action (np.ndarray): An action provided by the agent.

        Returns:
            EnvStep: The environment step resulting from the action.

        Raises:
            RuntimeError: if `step()` is called after the environment has been
                constructed and `reset()` has not been called.
            RuntimeError: if underlying environment outputs inconsistent
                env_info keys.
            RuntimeError: if the observation is not contained in the
                observation space and its number of elements differs from
                the space's flat dimension.

        """
        if self._step_cnt is None:
            raise RuntimeError('reset() must be called before step()!')

        observation, reward, done, info = self._env.step(action)

        if self._visualize:
            self._env.render(mode='human')

        if not isinstance(reward, float):
            reward = float(reward)
        self._step_cnt += 1

        step_type = StepType.get_step_type(
            step_cnt=self._step_cnt,
            max_episode_length=self._max_episode_length,
            done=done)

        # gym envs wrapped in the TimeLimit wrapper report done=True when
        # the time limit expires. Such steps are relabelled as TIMEOUT so
        # that only a genuine environment termination counts as TERMINAL.
        # The time limit signal is recorded in the env info under
        # 'GymEnv.TimeLimitTerminated'.
        hit_time_limit = ('TimeLimit.truncated' in info
                          or step_type == StepType.TIMEOUT)
        if hit_time_limit:
            info['GymEnv.TimeLimitTerminated'] = True
            # Keep a value already reported by the env, default to True.
            info.setdefault('TimeLimit.truncated', True)
            step_type = StepType.TIMEOUT
        else:
            info['TimeLimit.truncated'] = False
            info['GymEnv.TimeLimitTerminated'] = False

        # Episode is over: require a reset() before the next step().
        if step_type in (StepType.TERMINAL, StepType.TIMEOUT):
            self._step_cnt = None

        # Remember the env info layout when none is stored yet, and
        # compare every later step against it.
        if not self._env_info:
            self._env_info = {
                key: type(value)
                for key, value in info.items()
            }
        elif self._env_info.keys() != info.keys():
            raise RuntimeError('GymEnv outputs inconsistent env_info keys.')

        # An observation outside the space is still accepted when its
        # element count matches the space's flat dimension.
        obs_space = self.spec.observation_space
        if (not obs_space.contains(observation)
                and obs_space.flat_dim != np.prod(observation.shape)):
            raise RuntimeError('GymEnv observation shape does not '
                               'conform to its observation_space')

        return EnvStep(env_spec=self.spec,
                       action=action,
                       reward=reward,
                       observation=observation,
                       env_info=info,
                       step_type=step_type)
Example #11
0
    def step(self, action):
        """Step the environment.

        The action is clipped to the action space, added to the current
        point, and the point is clipped to the arena. The reward is the
        negative distance to the goal, plus the done bonus when the point
        is within reach of the goal.

        Args:
            action (np.ndarray): An action provided by the agent.

        Returns:
            EnvStep: The environment step resulting from the action.

        Raises:
            RuntimeError: if `step()` is called after the environment has been
                constructed and `reset()` has not been called.

        """
        if self._step_cnt is None:
            raise RuntimeError('reset() must be called before step()!')

        # Work on a copy so the caller's action is never modified.
        move = action.copy()
        move = np.clip(move, self.action_space.low, self.action_space.high)

        self._point = np.clip(self._point + move, -self._arena_size,
                              self._arena_size)
        if self._visualize:
            print(self.render('ascii'))

        distance = np.linalg.norm(self._point - self._goal)
        reached = distance < np.linalg.norm(self.action_space.low)

        # Dense reward, with a bonus once the goal is reached.
        reward = -distance
        if reached:
            reward += self._done_bonus
        if not isinstance(reward, float):
            reward = float(reward)

        # With `_never_done` set, reaching the goal does not end the episode.
        done = reached and not self._never_done

        observation = np.concatenate([self._point, (distance, )])

        self._step_cnt += 1

        step_type = StepType.get_step_type(
            step_cnt=self._step_cnt,
            max_episode_length=self._max_episode_length,
            done=done)

        # Episode is over: require a reset() before the next step().
        if step_type in (StepType.TERMINAL, StepType.TIMEOUT):
            self._step_cnt = None

        info = {'task': self._task, 'success': reached}
        return EnvStep(env_spec=self.spec,
                       action=action,
                       reward=reward,
                       observation=observation,
                       env_info=info,
                       step_type=step_type)
Example #12
0
    def step(self, action):
        """Steps the environment.

        The next state is drawn at random from the possible next states
        according to their probabilities.

        action map:
        0: left
        1: down
        2: right
        3: up

        Args:
            action (int): an int encoding the action

        Returns:
            EnvStep: The environment step resulting from the action.

        Raises:
            RuntimeError: if `step()` is called after the environment has been
                constructed and `reset()` has not been called.
            NotImplementedError: if a next step in self._desc does not match
                known state type.
        """
        if self._step_cnt is None:
            raise RuntimeError('reset() must be called before step()!')

        # Each candidate is indexed as (state, probability).
        candidates = self._get_possible_next_states(self._state, action)
        weights = [candidate[1] for candidate in candidates]
        chosen = np.random.choice(len(weights), p=weights)
        next_state = candidates[chosen][0]

        # Convert the flat state index into grid coordinates.
        row, col = divmod(next_state, self._n_col)
        tile = self._desc[row, col]

        if tile == 'G':
            done, reward = True, 1.0
        elif tile == 'H':
            done, reward = True, 0.0
        elif tile in ('F', 'S'):
            done, reward = False, 0.0
        else:
            raise NotImplementedError

        self._state = next_state
        self._step_cnt += 1

        step_type = StepType.get_step_type(
            step_cnt=self._step_cnt,
            max_episode_length=self._max_episode_length,
            done=done)

        # Episode is over: require a reset() before the next step().
        if step_type in (StepType.TERMINAL, StepType.TIMEOUT):
            self._step_cnt = None

        return EnvStep(env_spec=self.spec,
                       action=action,
                       reward=reward,
                       observation=next_state,
                       env_info={},
                       step_type=step_type)