Example 1
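This `_step` method appears to come from a randomized test-style Python environment: it validates the action against an action spec, decides episode termination from a minimum duration, a maximum duration, and a per-step end probability, optionally tiles the action across a batch dimension, and emits a termination or transition time step.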
    def _step(self, action):
        # A finished episode must be reset before it can be stepped again.
        if self._done:
            return self.reset()

        # Validate that the action matches the structure of the action spec.
        if self._action_spec:
            nest.assert_same_structure(self._action_spec, action)

        self._num_steps += 1

        observation = self._get_observation()
        # Termination rule: never end before the minimum duration, always
        # end at the maximum duration, otherwise end with a fixed per-step
        # probability.
        if self._num_steps < self._min_duration:
            self._done = False
        elif self._max_duration and self._num_steps >= self._max_duration:
            self._done = True
        else:
            self._done = self._rng.uniform() < self._episode_end_probability

        # Tile the action along a new leading batch dimension when the
        # environment is batched.
        if self._batch_size:
            action = nest.map_structure(
                lambda t: np.concatenate(
                    [np.expand_dims(t, 0)] * self._batch_size),
                action)

        if self._done:
            # Terminal step: emit a LAST time step and reset the step
            # counter for the next episode.
            reward = self._reward_fn(ds.StepType.LAST, action, observation)
            self._check_reward_shape(reward)
            time_step = ds.termination(observation, action, reward,
                                       self._env_id)
            self._num_steps = 0
        else:
            # Intermediate step: emit a MID transition with the discount.
            reward = self._reward_fn(ds.StepType.MID, action, observation)
            self._check_reward_shape(reward)
            time_step = ds.transition(observation, action, reward,
                                      self._discount, self._env_id)

        return time_step
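
The heart of this example is the three-way termination rule. Below is a minimal, self-contained sketch of just that rule; the free function episode_done, its parameter names, and the driver loop are assumptions introduced for illustration, not part of the original class.

    import numpy as np

    def episode_done(num_steps, min_duration, max_duration,
                     episode_end_probability, rng=None):
        """Return True if the episode should end after num_steps steps.

        Mirrors the example: never end before min_duration, always end at
        max_duration, otherwise end with a fixed per-step probability.
        """
        rng = rng if rng is not None else np.random.default_rng()
        if num_steps < min_duration:
            return False
        if max_duration and num_steps >= max_duration:
            return True
        return rng.uniform() < episode_end_probability

    # Hypothetical usage: episodes of at least 10 and at most 100 steps,
    # ending with probability 0.05 on each step in between.
    steps, done = 0, False
    while not done:
        steps += 1
        done = episode_done(steps, 10, 100, 0.05)
    print("episode length:", steps)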
Example 2
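This `_step` method wraps a Gym environment: it auto-resets a finished episode on the next step, advances the underlying environment, coerces the outputs to the declared spec dtypes, and returns a termination or transition time step with the env info attached.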
    def _step(self, action):
        # Automatically reset the environment if the previous step ended the
        # episode.
        if self._auto_reset and self._done:
            return self.reset()

        # Advance the underlying Gym environment.
        observation, reward, self._done, self._info = self._gym_env.step(
            action)
        # Coerce the observation to the declared spec dtypes and normalize
        # every info entry to a NumPy array.
        observation = self._to_spec_dtype_observation(observation)
        self._info = nest.map_structure(_as_array, self._info)

        if self._done:
            # Terminal step: package the final observation and reward into
            # a LAST time step, carrying the env info along.
            return ds.termination(
                observation,
                action,
                reward,
                self._reward_spec,
                self._env_id,
                env_info=self._info)
        else:
            # Intermediate step: emit a MID transition with the discount.
            return ds.transition(
                observation,
                action,
                reward,
                self._reward_spec,
                self._discount,
                self._env_id,
                env_info=self._info)
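
The first two lines of this example encode the wrapper's key contract: after a terminal step, the next step() call transparently begins a new episode instead of stepping a finished one. The toy class below reproduces only that auto-reset contract; ToyEnv and its (observation, reward, done) return tuple are stand-ins assumed for illustration, not the wrapper's real API.

    import random

    class ToyEnv:
        """Stand-in environment that mimics the auto-reset pattern above."""

        def __init__(self, auto_reset=True):
            self._auto_reset = auto_reset
            self._done = False
            self._t = 0

        def reset(self):
            # Begin a fresh episode and return its first time step.
            self._done = False
            self._t = 0
            return (self._t, 0.0, False)

        def step(self, action):
            # Same guard as the example: a finished episode restarts itself.
            if self._auto_reset and self._done:
                return self.reset()
            self._t += 1
            self._done = random.random() < 0.2  # toy termination rule
            return (self._t, 1.0, self._done)

    env = ToyEnv()
    obs, reward, done = env.reset()
    for _ in range(20):
        # No manual reset needed: terminal steps are followed by a restart.
        obs, reward, done = env.step(action=None)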