Example #1
  def step(self, action):
    """Updates the environment using the action and returns a `TimeStep`."""

    if self._reset_next_step:
      return self.reset()

    self._task.before_step(action, self._physics)
    for _ in range(self._n_sub_steps):
      self._physics.step()
    self._task.after_step(self._physics)

    reward = self._task.get_reward(self._physics)
    observation = self._task.get_observation(self._physics)
    if self._flat_observation:
      observation = flatten_observation(observation)

    if self._physics.time() >= self._time_limit:
      discount = 1.0
    else:
      discount = self._task.get_termination(self._physics)

    if discount is None:
      return environment.TimeStep(
          environment.StepType.MID, reward, 1.0, observation)
    else:
      self._reset_next_step = True
      return environment.TimeStep(
          environment.StepType.LAST, reward, discount, observation)
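All of these `step` implementations share one `TimeStep` contract: `reset()` returns a `StepType.FIRST` step whose reward and discount are `None`, `step()` returns `StepType.MID` steps with discount 1.0 while the episode runs, and the terminating step is `StepType.LAST` carrying the terminal discount (after which the next `step()` call silently resets). The driver loop below is a minimal sketch of that contract, not part of the original code: `run_episode`, `env` and `action_spec` are illustrative names, and `last_step_type` is passed in because the examples never show where their `environment` module is imported from.

import numpy as np

def run_episode(env, action_spec, last_step_type):
  """Drives one episode of an environment exposing the `reset`/`step` API above.

  `action_spec` is assumed to follow the dm_control convention of exposing
  `minimum`, `maximum` and `shape`; `last_step_type` should be
  `environment.StepType.LAST` from whichever module the environment uses.
  """
  time_step = env.reset()            # FIRST step: reward and discount are None.
  total_reward = 0.0
  while time_step.step_type != last_step_type:
    action = np.random.uniform(action_spec.minimum,
                               action_spec.maximum,
                               size=action_spec.shape)
    time_step = env.step(action)     # MID steps, then a single LAST step.
    total_reward += time_step.reward
  return total_reward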
Example #2
  def step(self, action):
    """Updates the environment using the action and returns a `TimeStep`."""

    if self._reset_next_step:
      return self.reset()

    self._task.before_step(action, self._physics)
    for _ in range(self._n_sub_steps):
      self._physics.step()
    self._task.after_step(self._physics)

    reward = self._task.get_reward(self._physics)
    observation = self._task.get_observation(self._physics)
    if self.concat_task_params_to_obs:
      observation['obs_task_params'] = self.obs_task_params
    if self._flat_observation:
      observation = flatten_observation(observation)

    self._step_count += 1
    if self._step_count >= self._step_limit:
      discount = 1.0
    else:
      discount = self._task.get_termination(self._physics)

    episode_over = discount is not None

    if episode_over:
      self._reset_next_step = True
      return environment.TimeStep(
          environment.StepType.LAST, reward, discount, observation)
    else:
      return environment.TimeStep(
          environment.StepType.MID, reward, 1.0, observation)
Example #3
    def step(self, action):
        """Updates the environment using the action and returns a `TimeStep`."""
        if self._reset_next_step:
            self._reset_next_step = False
            return self.reset()

        self._hooks.before_step(self._physics_proxy, action,
                                self._random_state)
        self._observation_updater.prepare_for_next_control_step()

        try:
            for i in range(self._n_sub_steps):
                self._hooks.before_substep(self._physics_proxy, action,
                                           self._random_state)
                self._physics.step()
                self._hooks.after_substep(self._physics_proxy,
                                          self._random_state)
                # The final observation update must happen after all the hooks in
                # `self._hooks.after_step` are called. Otherwise, if any of these
                # hooks modifies the physics state, we might capture an observation
                # that is inconsistent with the final physics state.
                if i < self._n_sub_steps - 1:
                    self._observation_updater.update(self._physics_proxy,
                                                     self._random_state)
            physics_is_divergent = False
        except control.PhysicsError as e:
            if not self._raise_exception_on_physics_error:
                logging.warning(e)
                physics_is_divergent = True
            else:
                raise

        self._hooks.after_step(self._physics_proxy, self._random_state)
        self._observation_updater.update(self._physics_proxy,
                                         self._random_state)

        if not physics_is_divergent:
            reward = self._task.get_reward(self._physics_proxy)
            discount = self._task.get_discount(self._physics_proxy)
            terminating = (self._task.should_terminate_episode(
                self._physics_proxy)
                           or self._physics.time() >= self._time_limit)
        else:
            reward = 0.0
            discount = 0.0
            terminating = True

        obs = self._observation_updater.get_observation()

        if not terminating:
            return environment.TimeStep(environment.StepType.MID, reward,
                                        discount, obs)
        else:
            self._reset_next_step = True
            return environment.TimeStep(environment.StepType.LAST, reward,
                                        discount, obs)
Example #4
    def step(self, action):
        if self._reset_next_step:
            return self.reset()

        obs, reward, done, _ = self._gym_env.step(action)
        observation = collections.OrderedDict()
        observation[control.FLAT_OBSERVATION_KEY] = obs

        if done:
            self._reset_next_step = True
            return environment.TimeStep(environment.StepType.LAST, reward, 0.0,
                                        observation)

        return environment.TimeStep(environment.StepType.MID, reward, 1.0,
                                    observation)
Example #5
  def reset(self, task_params=None, obs_task_params=None):
    """Starts a new episode and returns the first `TimeStep`.
    task_params: a dict of task params to pass to the task
    obs_task_params: a flat numpy array of task parameters that you'd give to your policy
    """
    self._reset_next_step = False
    self._step_count = 0
    if task_params is None:
      self.task_params, self.obs_task_params = self.task_params_sampler.sample()
    else:
      self.task_params, self.obs_task_params = task_params, obs_task_params
    with self._physics.reset_context():
      self._task.initialize_episode(self._physics, self.task_params)

    observation = self._task.get_observation(self._physics)
    if self.concat_task_params_to_obs:
      observation['obs_task_params'] = self.obs_task_params
    if self._flat_observation:
      observation = flatten_observation(observation)
    
    return environment.TimeStep(
        step_type=environment.StepType.FIRST,
        reward=None,
        discount=None,
        observation=observation
    )
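Example #5 extends the plain `reset` of Example #10 for a multi-task / meta-RL setup: when `task_params` is omitted, a new task is drawn from `task_params_sampler`, and the observable task parameters can be appended to the observation under the `'obs_task_params'` key. The sketch below shows how a caller might use both code paths; `meta_env` and `sample_and_replay_task` are hypothetical names, and only `reset`, `task_params_sampler.sample()` and the observation key come from the snippet itself.

def sample_and_replay_task(meta_env):
  """Illustrates two ways of driving the multi-task `reset` in Example #5."""
  # Meta-training: let the environment draw a fresh task for this episode.
  first_step = meta_env.reset()

  # Evaluation on a fixed task: take one sample from the task sampler and
  # pass it back in, so the same task can be replayed across episodes.
  task_params, obs_task_params = meta_env.task_params_sampler.sample()
  first_step = meta_env.reset(task_params=task_params,
                              obs_task_params=obs_task_params)

  # With `concat_task_params_to_obs` enabled (and no flattening), the task
  # description reaches the policy under this observation key.
  return first_step.observation['obs_task_params']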
Example #6
 def step_spec(self):
     if (self._task.get_reward_spec() is None
             or self._task.get_discount_spec() is None):
         raise NotImplementedError
     return environment.TimeStep(
         step_type=None,
         reward=self._task.get_reward_spec(),
         discount=self._task.get_discount_spec(),
         observation=self._observation_updater.observation_spec(),
     )
Example #7
 def reset(self):
     self._reset_next_step = False
     obs = self._gym_env.reset()
     observation = collections.OrderedDict()
     observation[control.FLAT_OBSERVATION_KEY] = obs
     return environment.TimeStep(
         step_type=environment.StepType.FIRST,
         reward=None,
         discount=None,
         observation=observation,
     )
Example #8
 def _reset_attempt(self):
   self._hooks.initialize_episode_mjcf(self._random_state)
   self._recompile_physics()
   with self._physics.reset_context():
     self._hooks.initialize_episode(self._physics_proxy, self._random_state)
   self._observation_updater.reset(self._physics_proxy, self._random_state)
   self._reset_next_step = False
   return environment.TimeStep(
       step_type=environment.StepType.FIRST,
       reward=None,
       discount=None,
       observation=self._observation_updater.get_observation())
Example #9
 def step_spec(self):
   """DEPRECATED: please use `reward_spec` and `discount_spec` instead."""
   warnings.warn('`step_spec` is deprecated, please use `reward_spec` and '
                 '`discount_spec` instead.', DeprecationWarning)
   if (self._task.get_reward_spec() is None or
       self._task.get_discount_spec() is None):
     raise NotImplementedError
   return environment.TimeStep(
       step_type=None,
       reward=self._task.get_reward_spec(),
       discount=self._task.get_discount_spec(),
       observation=self._observation_updater.observation_spec(),
   )
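Example #9 is the same spec helper as Example #6, but its deprecation warning points callers to separate `reward_spec` and `discount_spec` methods. Below is a hedged sketch of what such replacements could look like on the same environment class; the fallback specs use `dm_env.specs`, which is an assumption on my part since the examples never show which spec types the tasks return.

from dm_env import specs

class SpecMixinSketch:
  """Hypothetical `step_spec` replacements; not taken from the examples."""

  def reward_spec(self):
    # Prefer the task-provided spec; otherwise fall back to a scalar float.
    task_spec = self._task.get_reward_spec()
    if task_spec is not None:
      return task_spec
    return specs.Array(shape=(), dtype=float, name='reward')

  def discount_spec(self):
    # Prefer the task-provided spec; otherwise a scalar bounded to [0, 1].
    task_spec = self._task.get_discount_spec()
    if task_spec is not None:
      return task_spec
    return specs.BoundedArray(
        shape=(), dtype=float, minimum=0.0, maximum=1.0, name='discount')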
Example #10
    def reset(self):
        """Starts a new episode and returns the first `TimeStep`."""
        self._reset_next_step = False
        with self._physics.reset_context():
            self._task.initialize_episode(self._physics)

        observation = self._task.get_observation(self._physics)
        if self._flat_observation:
            observation = flatten_observation(observation)

        return environment.TimeStep(step_type=environment.StepType.FIRST,
                                    reward=None,
                                    discount=None,
                                    observation=observation)