Example #1
 def step(self, action):
     time_step = self._env.step(action)
     # Reward is None on the first (restart) timestep.
     if time_step.reward is not None:
         self._total_reward += time_step.reward
     # flatten_observation returns a dict; index it to get the flat array,
     # and forward the original observation entries as info kwargs.
     return Step(flatten_observation(time_step.observation)['observations'],
                 time_step.reward, time_step.step_type == StepType.LAST,
                 **time_step.observation)
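
These examples all lean on the same helper: ``flatten_observation`` from ``dm_control.rl.control``, which ravels each entry of an observation dict and concatenates them into a single flat vector stored under one key (``'observations'``, a.k.a. ``FLAT_OBSERVATION_KEY``, by default). A minimal sketch of that behavior, assuming only dm_control and numpy:

import collections
import numpy as np
from dm_control.rl.control import FLAT_OBSERVATION_KEY, flatten_observation

obs = collections.OrderedDict([
    ('position', np.array([0.1, 0.2])),
    ('velocity', np.array([0.3])),
])
# Entries are raveled and concatenated in insertion order.
flat = flatten_observation(obs)
print(flat[FLAT_OBSERVATION_KEY])  # [0.1 0.2 0.3]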
Example #2
    def reset(self):
        """Reset the environment.

        Returns:
            numpy.ndarray: The first (flattened) observation.
        """
        time_step = self._env.reset()
        return flatten_observation(time_step.observation)['observations']
Example #3
    def _get_observation(self, timestep):
        """Extract the observation from the timestep returned by ``dmcenv.step``.

        Returns:
            tuple: ``(observation, info)``
        """
        info = self._extract_obs_info(timestep.observation)
        if self._flat_observation:
            return flatten_observation(timestep.observation,
                                       output_key=self._observation_key), info
        else:
            return timestep.observation, info
Example #4
    def step(self, action):
        """Step the environment.

        Args:
            action (object): input action

        Returns:
            Step: The time step after applying this action.
        """
        time_step = self._env.step(action)
        return Step(
            flatten_observation(time_step.observation)['observations'],
            time_step.reward, time_step.step_type == StepType.LAST,
            **time_step.observation)
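
Examples #1, #4 and #14 pass the raw observation dict as keyword arguments to ``Step``. In the (older) garage API these snippets appear to come from, ``Step`` is a function over a namedtuple that collects those extras as per-step info; the definition below is an assumption reconstructed from the call sites above, not taken from the examples:

import collections

_Step = collections.namedtuple('Step', ['observation', 'reward', 'done', 'info'])

def Step(observation, reward, done, **kwargs):
    # Extra keyword arguments (here, the unflattened observation
    # entries) are collected into the transition's info dict.
    return _Step(observation, reward, done, kwargs)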
Example #5
    def reset(self):
        """Resets the environment.

        Returns:
            numpy.ndarray: The first observation conforming to
                `observation_space`.
            dict: The episode-level information.
                Note that this is not part of `env_info` provided in `step()`.
                It contains information about the entire episode, which could
                be needed to determine the first action (e.g. in the case of
                goal-conditioned or multi-task RL).

        """
        time_step = self._env.reset()
        first_obs = flatten_observation(time_step.observation)['observations']

        self._step_cnt = 0
        return first_obs, {}
Example #6
 def reset(self):
   """Starts a new episode and returns the first `TimeStep`."""
   if self._stats_acc:
     self._stats_acc.clear_buffer()
   if self._task.perturb_enabled:
     if self._counter % self._task.perturb_period == 0:
       self._physics = self._task.update_physics()
     self._counter += 1
   timestep = super(LoggingEnv, self).reset()
   self._track(timestep)
   if self._flat_observation_:
     timestep = dm_env.TimeStep(
         step_type=timestep.step_type,
         reward=None,
         discount=None,
         observation=control.flatten_observation(
             timestep.observation)['observations'])
   return timestep
Example #7
    def step(self, action):
        """Steps the environment with the action and returns a `EnvStep`.

        Args:
            action (object): input action

        Returns:
            EnvStep: The environment step resulting from the action.

        Raises:
            RuntimeError: if `step()` is called after the environment has been
                constructed and `reset()` has not been called.
        """
        if self._step_cnt is None:
            raise RuntimeError('reset() must be called before step()!')

        dm_time_step = self._env.step(action)
        if self._viewer:
            self._viewer.render()

        observation = flatten_observation(
            dm_time_step.observation)['observations']

        self._step_cnt += 1

        # Determine step type
        step_type = None
        if dm_time_step.step_type == dm_StepType.MID:
            if self._step_cnt >= self._max_episode_length:
                step_type = StepType.TIMEOUT
            else:
                step_type = StepType.MID
        elif dm_time_step.step_type == dm_StepType.LAST:
            step_type = StepType.TERMINAL

        if step_type in (StepType.TERMINAL, StepType.TIMEOUT):
            self._step_cnt = None

        return EnvStep(env_spec=self.spec,
                       action=action,
                       reward=dm_time_step.reward,
                       observation=observation,
                       env_info=dm_time_step.observation,
                       step_type=step_type)
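
Example #7 separates a true terminal state (``StepType.TERMINAL``) from an episode cut off by the time limit (``StepType.TIMEOUT``). The distinction matters downstream: on a timeout the next state still has value, so a learner should bootstrap from it. A sketch of a consumer using the distinction, assuming garage's ``StepType`` enum as used in the example:

def td_target(env_step, discount, value_fn):
    # A timeout only truncates the episode, so bootstrap from the
    # value of the final observation; a terminal state has no future.
    if env_step.step_type == StepType.TIMEOUT:
        return env_step.reward + discount * value_fn(env_step.observation)
    return env_step.reward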
Example #8
 def step(self, action):
   """Updates the environment using the action and returns a `TimeStep`."""
   do_track = not self._reset_next_step
   timestep = super(LoggingEnv, self).step(action)
   if do_track:
     self._track(timestep)
   if timestep.last():
     self._ep_counter += 1
     if self._ep_counter % self._log_every == 0:
       self.write_logs()
   # Only flatten observation if we're not forwarding one from a reset(),
   # as it will already be flattened.
   if self._flat_observation_ and not timestep.first():
     timestep = dm_env.TimeStep(
         step_type=timestep.step_type,
         reward=timestep.reward,
         discount=timestep.discount,
         observation=control.flatten_observation(
             timestep.observation)['observations'])
   return timestep
Example #9
    def _get_observation_spec(self):
        """Extract the environment's ``observation_spec`` if it is specified
        explicitly. Otherwise, first extract the ``info`` key from the
        observation dict, if it exists, and then build the spec from the
        observation dict.
        """
        try:
            return self.dmcenv.task.observation_spec(self.dmcenv.physics)
        except NotImplementedError:
            observation = self.dmcenv.task.get_observation(self.dmcenv.physics)
            self._extract_obs_info(observation)
            if self._flat_observation:
                observation = flatten_observation(
                    observation, output_key=self._observation_key)
            return _spec_from_observation(observation)
Example #10
 def reset(self):
     ts = self._dmenv.reset()
     obs = flatten_observation(ts.observation)
     return obs[FLAT_OBSERVATION_KEY]
Example #11
 def step(self, action):
     ts = self._dmenv.step(action)
     obs = flatten_observation(ts.observation)[FLAT_OBSERVATION_KEY]
     reward = ts.reward
     done = ts.step_type.last()
     return obs, reward, done, {}
Example #12
 def observation_space(self):
     obs = flatten_observation(
         self._dmenv.task.get_observation(self._dmenv.physics))[FLAT_OBSERVATION_KEY]
     return spaces.Box(-np.inf, np.inf, shape=obs.shape)
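
Examples #10 through #12 together form a minimal Gym-style facade over a dm_control environment. A hypothetical usage sketch; the ``DMCGymWrapper`` name and its constructor are assumptions, and only ``reset()``, ``step()`` and ``observation_space()`` are taken from the examples above:

import numpy as np
from dm_control import suite

env = DMCGymWrapper(suite.load('cartpole', 'swingup'))  # hypothetical class
spec = env._dmenv.action_spec()  # the wrapped dm_control env, as in the examples
obs = env.reset()
done, episode_return = False, 0.0
while not done:
    action = np.random.uniform(spec.minimum, spec.maximum, size=spec.shape)
    obs, reward, done, info = env.step(action)
    episode_return += reward

Note that Example #12 leaves the ``Box`` dtype unset; depending on the gym version this may default to ``float32`` while dm_control observations are ``float64``, so passing ``dtype=obs.dtype`` is a possible refinement.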
Example #13
 def reset(self):
     time_step = self._env.reset()
     return flatten_observation(time_step.observation)['observations']
Example #14
 def step(self, action):
     time_step = self._env.step(action)
     return Step(
         flatten_observation(time_step.observation)['observations'],
         time_step.reward, time_step.step_type == StepType.LAST,
         **time_step.observation)
Example #15
 def reset(self):
     self._total_reward = 0
     time_step = self._env.reset()
     # Index the flattened observation dict, matching step().
     return flatten_observation(time_step.observation)['observations']